diff options
author | Davies Liu <davies@databricks.com> | 2015-08-03 04:23:26 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-08-03 04:23:26 -0700 |
commit | 191bf2689d127a9dd328b9cc517362fd51eaed3d (patch) | |
tree | 4fcc6dc0c9003da69a8397cd469ea14c9d255979 /sql | |
parent | 137f47865df6e98ab70ae5ba30dc4d441fb41166 (diff) | |
download | spark-191bf2689d127a9dd328b9cc517362fd51eaed3d.tar.gz spark-191bf2689d127a9dd328b9cc517362fd51eaed3d.tar.bz2 spark-191bf2689d127a9dd328b9cc517362fd51eaed3d.zip |
[SPARK-9518] [SQL] cleanup generated UnsafeRowJoiner and fix bug
Currently, when copy the bitsets, we didn't consider that the row1 may not sit in the beginning of byte array.
cc rxin
Author: Davies Liu <davies@databricks.com>
Closes #7892 from davies/clean_join and squashes the following commits:
14cce9e [Davies Liu] cleanup generated UnsafeRowJoiner and fix bug
Diffstat (limited to 'sql')
2 files changed, 37 insertions, 72 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 645eb48d5a..5f8a6f8871 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner { */ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] { - def dump(word: Long): String = { - Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString - } - override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = { create(in._1, in._2) } @@ -56,76 +52,45 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U } def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = { - val ctx = newCodeGenContext() val offset = PlatformDependent.BYTE_ARRAY_OFFSET + val getLong = "PlatformDependent.UNSAFE.getLong" + val putLong = "PlatformDependent.UNSAFE.putLong" val bitset1Words = (schema1.size + 63) / 64 val bitset2Words = (schema2.size + 63) / 64 val outputBitsetWords = (schema1.size + schema2.size + 63) / 64 val bitset1Remainder = schema1.size % 64 - val bitset2Remainder = schema2.size % 64 // The number of words we can reduce when we concat two rows together. // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords - // --------------------- copy bitset from row 1 ----------------------- // - val copyBitset1 = Seq.tabulate(bitset1Words) { i => - s""" - |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8}, - | PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8})); - """.stripMargin - }.mkString - - - // --------------------- copy bitset from row 2 ----------------------- // - var copyBitset2 = "" - if (bitset1Remainder == 0) { - copyBitset2 += Seq.tabulate(bitset2Words) { i => - s""" - |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8}, - | PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8})); - """.stripMargin - }.mkString - } else { - copyBitset2 = Seq.tabulate(bitset2Words) { i => - s""" - |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}); - |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1); - |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder}); - """.stripMargin - }.mkString - - copyBitset2 += Seq.tabulate(bitset2Words) { i => - val currentOffset = offset + (bitset1Words + i - 1) * 8 - if (i == 0) { - if (bitset1Words > 0) { - s""" - |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, - | bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset)); - """.stripMargin - } else { - s""" - |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1); - """.stripMargin - } + // --------------------- copy bitset from row 1 and row 2 --------------------------- // + val copyBitset = Seq.tabulate(outputBitsetWords) { i => + val bits = if (bitset1Remainder > 0) { + if (i < bitset1Words - 1) { + s"$getLong(obj1, offset1 + ${i * 8})" + } else if (i == bitset1Words - 1) { + // combine last work of bitset1 and first word of bitset2 + s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)" + } else if (i - bitset1Words < bitset2Words - 1) { + // combine next two words of bitset2 + s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" + + s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)" + } else { + // last word of bitset2 + s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)" + } + } else { + // they are aligned by word + if (i < bitset1Words) { + s"$getLong(obj1, offset1 + ${i * 8})" } else { - s""" - |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2); - """.stripMargin + s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})" } - }.mkString("\n") - - if (bitset2Words > 0 && - (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) { - val lastWord = bitset2Words - 1 - copyBitset2 += - s""" - |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8}, - | bs2w${lastWord}p2); - """.stripMargin } - } + s"$putLong(buf, ${offset + i * 8}, $bits);" + }.mkString("\n") // --------------------- copy fixed length portion from row 1 ----------------------- // var cursor = offset + outputBitsetWords * 8 @@ -149,10 +114,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U cursor += schema2.size * 8 // --------------------- copy variable length portion from row 1 ----------------------- // + val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8 val copyVariableLengthRow1 = s""" |// Copy variable length data for row1 - |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8}; - |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1; + |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1; |PlatformDependent.copyMemory( | obj1, offset1 + ${(bitset1Words + schema1.size) * 8}, | buf, $cursor, @@ -160,10 +125,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U """.stripMargin // --------------------- copy variable length portion from row 2 ----------------------- // + val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8 val copyVariableLengthRow2 = s""" |// Copy variable length data for row2 - |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8}; - |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2; + |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2; |PlatformDependent.copyMemory( | obj2, offset2 + ${(bitset2Words + schema2.size) * 8}, | buf, $cursor + numBytesVariableRow1, @@ -183,12 +148,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U if (i < schema1.size) { s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L" } else { - s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1" + s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)" } val cursor = offset + outputBitsetWords * 8 + i * 8 s""" - |PlatformDependent.UNSAFE.putLong(buf, $cursor, - | PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32)); + |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32)); """.stripMargin } }.mkString @@ -217,8 +181,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | final Object obj2 = row2.getBaseObject(); | final long offset2 = row2.getBaseOffset(); | - | $copyBitset1 - | $copyBitset2 + | $copyBitset | $copyFixedLengthRow1 | $copyFixedLengthRow2 | $copyVariableLengthRow1 @@ -233,7 +196,6 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U """.stripMargin logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - // println(CodeFormatter.format(code)) val c = compile(code) c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index 76d9d991ed..718a2acc82 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -22,6 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.PlatformDependent /** * A test suite for the bitset portion of the row concatenation. @@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { private def createUnsafeRow(numFields: Int): UnsafeRow = { val row = new UnsafeRow val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 - val buf = new Array[Byte](sizeInBytes) - row.pointTo(buf, numFields, sizeInBytes) + val offset = numFields * 8 + val buf = new Array[Byte](sizeInBytes + offset) + row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) row } @@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { |input1: ${set1.mkString} |input2: ${set2.mkString} |output: ${out.mkString} + |expect: ${set1.mkString}${set2.mkString} """.stripMargin } |