aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-08-03 04:23:26 -0700
committerReynold Xin <rxin@databricks.com>2015-08-03 04:23:26 -0700
commit191bf2689d127a9dd328b9cc517362fd51eaed3d (patch)
tree4fcc6dc0c9003da69a8397cd469ea14c9d255979 /sql
parent137f47865df6e98ab70ae5ba30dc4d441fb41166 (diff)
downloadspark-191bf2689d127a9dd328b9cc517362fd51eaed3d.tar.gz
spark-191bf2689d127a9dd328b9cc517362fd51eaed3d.tar.bz2
spark-191bf2689d127a9dd328b9cc517362fd51eaed3d.zip
[SPARK-9518] [SQL] cleanup generated UnsafeRowJoiner and fix bug
Currently, when copying the bitsets, we do not consider that row1 may not sit at the beginning of the byte array. cc rxin Author: Davies Liu <davies@databricks.com> Closes #7892 from davies/clean_join and squashes the following commits: 14cce9e [Davies Liu] cleanup generated UnsafeRowJoiner and fix bug
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala102
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala7
2 files changed, 37 insertions, 72 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index 645eb48d5a..5f8a6f8871 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -40,10 +40,6 @@ abstract class UnsafeRowJoiner {
*/
object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), UnsafeRowJoiner] {
- def dump(word: Long): String = {
- Seq.tabulate(64) { i => if ((word >> i) % 2 == 0) "0" else "1" }.reverse.mkString
- }
-
override protected def create(in: (StructType, StructType)): UnsafeRowJoiner = {
create(in._1, in._2)
}
@@ -56,76 +52,45 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
}
def create(schema1: StructType, schema2: StructType): UnsafeRowJoiner = {
- val ctx = newCodeGenContext()
val offset = PlatformDependent.BYTE_ARRAY_OFFSET
+ val getLong = "PlatformDependent.UNSAFE.getLong"
+ val putLong = "PlatformDependent.UNSAFE.putLong"
val bitset1Words = (schema1.size + 63) / 64
val bitset2Words = (schema2.size + 63) / 64
val outputBitsetWords = (schema1.size + schema2.size + 63) / 64
val bitset1Remainder = schema1.size % 64
- val bitset2Remainder = schema2.size % 64
// The number of words we can reduce when we concat two rows together.
// The only reduction comes from merging the bitset portion of the two rows, saving 1 word.
val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords
- // --------------------- copy bitset from row 1 ----------------------- //
- val copyBitset1 = Seq.tabulate(bitset1Words) { i =>
- s"""
- |PlatformDependent.UNSAFE.putLong(buf, ${offset + i * 8},
- | PlatformDependent.UNSAFE.getLong(obj1, ${offset + i * 8}));
- """.stripMargin
- }.mkString
-
-
- // --------------------- copy bitset from row 2 ----------------------- //
- var copyBitset2 = ""
- if (bitset1Remainder == 0) {
- copyBitset2 += Seq.tabulate(bitset2Words) { i =>
- s"""
- |PlatformDependent.UNSAFE.putLong(buf, ${offset + (bitset1Words + i) * 8},
- | PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8}));
- """.stripMargin
- }.mkString
- } else {
- copyBitset2 = Seq.tabulate(bitset2Words) { i =>
- s"""
- |long bs2w$i = PlatformDependent.UNSAFE.getLong(obj2, ${offset + i * 8});
- |long bs2w${i}p1 = (bs2w$i << $bitset1Remainder) & ~((1L << $bitset1Remainder) - 1);
- |long bs2w${i}p2 = (bs2w$i >>> ${64 - bitset1Remainder});
- """.stripMargin
- }.mkString
-
- copyBitset2 += Seq.tabulate(bitset2Words) { i =>
- val currentOffset = offset + (bitset1Words + i - 1) * 8
- if (i == 0) {
- if (bitset1Words > 0) {
- s"""
- |PlatformDependent.UNSAFE.putLong(buf, $currentOffset,
- | bs2w${i}p1 | PlatformDependent.UNSAFE.getLong(obj1, $currentOffset));
- """.stripMargin
- } else {
- s"""
- |PlatformDependent.UNSAFE.putLong(buf, $currentOffset + 8, bs2w${i}p1);
- """.stripMargin
- }
+ // --------------------- copy bitset from row 1 and row 2 --------------------------- //
+ val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
+ val bits = if (bitset1Remainder > 0) {
+ if (i < bitset1Words - 1) {
+ s"$getLong(obj1, offset1 + ${i * 8})"
+ } else if (i == bitset1Words - 1) {
+ // combine last word of bitset1 and first word of bitset2
+ s"$getLong(obj1, offset1 + ${i * 8}) | ($getLong(obj2, offset2) << $bitset1Remainder)"
+ } else if (i - bitset1Words < bitset2Words - 1) {
+ // combine next two words of bitset2
+ s"($getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder))" +
+ s"| ($getLong(obj2, offset2 + ${(i - bitset1Words + 1) * 8}) << $bitset1Remainder)"
+ } else {
+ // last word of bitset2
+ s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8}) >>> (64 - $bitset1Remainder)"
+ }
+ } else {
+ // they are aligned by word
+ if (i < bitset1Words) {
+ s"$getLong(obj1, offset1 + ${i * 8})"
} else {
- s"""
- |PlatformDependent.UNSAFE.putLong(buf, $currentOffset, bs2w${i}p1 | bs2w${i - 1}p2);
- """.stripMargin
+ s"$getLong(obj2, offset2 + ${(i - bitset1Words) * 8})"
}
- }.mkString("\n")
-
- if (bitset2Words > 0 &&
- (bitset2Remainder == 0 || bitset2Remainder > (64 - bitset1Remainder))) {
- val lastWord = bitset2Words - 1
- copyBitset2 +=
- s"""
- |PlatformDependent.UNSAFE.putLong(buf, ${offset + (outputBitsetWords - 1) * 8},
- | bs2w${lastWord}p2);
- """.stripMargin
}
- }
+ s"$putLong(buf, ${offset + i * 8}, $bits);"
+ }.mkString("\n")
// --------------------- copy fixed length portion from row 1 ----------------------- //
var cursor = offset + outputBitsetWords * 8
@@ -149,10 +114,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
cursor += schema2.size * 8
// --------------------- copy variable length portion from row 1 ----------------------- //
+ val numBytesBitsetAndFixedRow1 = (bitset1Words + schema1.size) * 8
val copyVariableLengthRow1 = s"""
|// Copy variable length data for row1
- |long numBytesBitsetAndFixedRow1 = ${(bitset1Words + schema1.size) * 8};
- |long numBytesVariableRow1 = row1.getSizeInBytes() - numBytesBitsetAndFixedRow1;
+ |long numBytesVariableRow1 = row1.getSizeInBytes() - $numBytesBitsetAndFixedRow1;
|PlatformDependent.copyMemory(
| obj1, offset1 + ${(bitset1Words + schema1.size) * 8},
| buf, $cursor,
@@ -160,10 +125,10 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
""".stripMargin
// --------------------- copy variable length portion from row 2 ----------------------- //
+ val numBytesBitsetAndFixedRow2 = (bitset2Words + schema2.size) * 8
val copyVariableLengthRow2 = s"""
|// Copy variable length data for row2
- |long numBytesBitsetAndFixedRow2 = ${(bitset2Words + schema2.size) * 8};
- |long numBytesVariableRow2 = row2.getSizeInBytes() - numBytesBitsetAndFixedRow2;
+ |long numBytesVariableRow2 = row2.getSizeInBytes() - $numBytesBitsetAndFixedRow2;
|PlatformDependent.copyMemory(
| obj2, offset2 + ${(bitset2Words + schema2.size) * 8},
| buf, $cursor + numBytesVariableRow1,
@@ -183,12 +148,11 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
if (i < schema1.size) {
s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
} else {
- s"${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1"
+ s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
}
val cursor = offset + outputBitsetWords * 8 + i * 8
s"""
- |PlatformDependent.UNSAFE.putLong(buf, $cursor,
- | PlatformDependent.UNSAFE.getLong(buf, $cursor) + ($shift << 32));
+ |$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));
""".stripMargin
}
}.mkString
@@ -217,8 +181,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
| final Object obj2 = row2.getBaseObject();
| final long offset2 = row2.getBaseOffset();
|
- | $copyBitset1
- | $copyBitset2
+ | $copyBitset
| $copyFixedLengthRow1
| $copyFixedLengthRow2
| $copyVariableLengthRow1
@@ -233,7 +196,6 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
""".stripMargin
logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}")
- // println(CodeFormatter.format(code))
val c = compile(code)
c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner]
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
index 76d9d991ed..718a2acc82 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala
@@ -22,6 +22,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.PlatformDependent
/**
* A test suite for the bitset portion of the row concatenation.
@@ -91,8 +92,9 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
private def createUnsafeRow(numFields: Int): UnsafeRow = {
val row = new UnsafeRow
val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8
- val buf = new Array[Byte](sizeInBytes)
- row.pointTo(buf, numFields, sizeInBytes)
+ val offset = numFields * 8
+ val buf = new Array[Byte](sizeInBytes + offset)
+ row.pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes)
row
}
@@ -133,6 +135,7 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite {
|input1: ${set1.mkString}
|input2: ${set2.mkString}
|output: ${out.mkString}
+ |expect: ${set1.mkString}${set2.mkString}
""".stripMargin
}