aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala84
1 files changed, 24 insertions, 60 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 03c5f449bf..55562facf9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
+ val tmpBuffer = ctx.freshName("tmpBuffer")
val outputIsNull = ctx.freshName("isNull")
val numElements = ctx.freshName("numElements")
val fixedSize = ctx.freshName("fixedSize")
val numBytes = ctx.freshName("numBytes")
- val elements = ctx.freshName("elements")
val cursor = ctx.freshName("cursor")
val index = ctx.freshName("index")
val elementName = ctx.freshName("elementName")
@@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val convertedElement = createConvertCode(ctx, element, elementType)
- // go through the input array to calculate how many bytes we need.
- val calculateNumBytes = elementType match {
- case _ if ctx.isPrimitiveType(elementType) =>
- // Should we do word align?
- val elementSize = elementType.defaultSize
- s"""
- $numBytes += $elementSize * $numElements;
- """
- case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
- s"""
- $numBytes += 8 * $numElements;
- """
- case _ =>
- val writer = getWriter(elementType)
- val elementSize = s"$writer.getSize($elements[$index])"
- // TODO(davies): avoid the copy
- val unsafeType = elementType match {
- case _: StructType => "UnsafeRow"
- case _: ArrayType => "UnsafeArrayData"
- case _: MapType => "UnsafeMapData"
- case _ => ctx.javaType(elementType)
- }
- val copy = elementType match {
- // We reuse the buffer during conversion, need copy it before process next element.
- case _: StructType | _: ArrayType | _: MapType => ".copy()"
- case _ => ""
- }
-
- val newElements = if (elementType == BinaryType) {
- s"new byte[$numElements][]"
- } else {
- s"new $unsafeType[$numElements]"
- }
- s"""
- final $unsafeType[] $elements = $newElements;
- for (int $index = 0; $index < $numElements; $index++) {
- ${convertedElement.code}
- if (!${convertedElement.isNull}) {
- $elements[$index] = ${convertedElement.primitive}$copy;
- $numBytes += $elementSize;
- }
- }
- """
- }
-
val writeElement = elementType match {
case _ if ctx.isPrimitiveType(elementType) =>
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
- ${convertedElement.code}
Platform.put${ctx.primitiveTypeName(elementType)}(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
- ${convertedElement.code}
Platform.putLong(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$cursor += $writer.write(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
- $elements[$index]);
+ ${convertedElement.primitive});
"""
}
- val checkNull = elementType match {
- case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}"
- case t: DecimalType => s"$elements[$index] == null" +
- s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})"
- case _ => s"$elements[$index] == null"
+ val checkNull = convertedElement.isNull + (elementType match {
+ case t: DecimalType =>
+ s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})"
+ case _ => ""
+ })
+
+ val elementSize = elementType match {
+ // Should we do word align for primitive types?
+ case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString
+ case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8"
+ case _ =>
+ val writer = getWriter(elementType)
+ s"$writer.getSize(${convertedElement.primitive})"
}
val code = s"""
@@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $fixedSize = 4 * $numElements;
int $numBytes = $fixedSize;
- $calculateNumBytes
-
- if ($numBytes > $buffer.length) {
- $buffer = new byte[$numBytes];
- }
-
int $cursor = $fixedSize;
for (int $index = 0; $index < $numElements; $index++) {
+ ${convertedElement.code}
if ($checkNull) {
// If element is null, write the negative value address into offset region.
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor);
} else {
+ $numBytes += $elementSize;
+ if ($buffer.length < $numBytes) {
+ // This will not happen frequently, because the buffer is re-used.
+ byte[] $tmpBuffer = new byte[$numBytes * 2];
+ Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
+ $tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
+ $buffer = $tmpBuffer;
+ }
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor);
$writeElement
}