diff options
author | Davies Liu <davies@databricks.com> | 2015-08-17 23:27:55 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-17 23:27:55 -0700 |
commit | 5af3838d2e59ed83766f85634e26918baa53819f (patch) | |
tree | d75339222a3d0270b42bd5a33591429491fda3e9 | |
parent | a0910315dae88b033e38a1de07f39ca21f6552ad (diff) | |
download | spark-5af3838d2e59ed83766f85634e26918baa53819f.tar.gz spark-5af3838d2e59ed83766f85634e26918baa53819f.tar.bz2 spark-5af3838d2e59ed83766f85634e26918baa53819f.zip |
[SPARK-10038] [SQL] fix bug in generated unsafe projection when there is binary in ArrayData
The type for array of array in Java is slightly different than array of others.
cc cloud-fan
Author: Davies Liu <davies@databricks.com>
Closes #8250 from davies/array_binary.
2 files changed, 29 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index b2fb913850..b570fe86db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -224,7 +224,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro // go through the input array to calculate how many bytes we need. val calculateNumBytes = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => + case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" @@ -237,6 +237,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => val writer = getWriter(elementType) val elementSize = s"$writer.getSize($elements[$index])" + // TODO(davies): avoid the copy val unsafeType = elementType match { case _: StructType => "UnsafeRow" case _: ArrayType => "UnsafeArrayData" @@ -249,8 +250,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => "" } + val newElements = if (elementType == BinaryType) { + s"new byte[$numElements][]" + } else { + s"new $unsafeType[$numElements]" + } s""" - final $unsafeType[] $elements = new $unsafeType[$numElements]; + final $unsafeType[] $elements = $newElements; for (int $index = 0; $index < $numElements; $index++) { ${convertedElement.code} if (!${convertedElement.isNull}) { @@ -262,7 +268,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro } val writeElement = elementType match { - case _ if (ctx.isPrimitiveType(elementType)) => + case _ if ctx.isPrimitiveType(elementType) => // Should we do word align? val elementSize = elementType.defaultSize s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index 8c7ee8720f..098944a9f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegerType, StructField, StructType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /** @@ -79,4 +79,23 @@ class GeneratedProjectionSuite extends SparkFunSuite { val row2 = mutableProj(result) assert(result === row2) } + + test("generated unsafe projection with array of binary") { + val row = InternalRow( + Array[Byte](1, 2), + new GenericArrayData(Array(Array[Byte](1, 2), null, Array[Byte](3, 4)))) + val fields = (BinaryType :: ArrayType(BinaryType) :: Nil).toArray[DataType] + + val unsafeProj = UnsafeProjection.create(fields) + val unsafeRow: UnsafeRow = unsafeProj(row) + assert(java.util.Arrays.equals(unsafeRow.getBinary(0), Array[Byte](1, 2))) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(0), Array[Byte](1, 2))) + assert(unsafeRow.getArray(1).isNullAt(1)) + assert(unsafeRow.getArray(1).getBinary(1) === null) + assert(java.util.Arrays.equals(unsafeRow.getArray(1).getBinary(2), Array[Byte](3, 4))) + + val safeProj = FromUnsafeProjection(fields) + val row2 = safeProj(unsafeRow) + assert(row2 === row) + } } |