author     Davies Liu <davies@databricks.com>   2016-09-06 10:46:31 -0700
committer  Davies Liu <davies.liu@gmail.com>    2016-09-06 10:46:31 -0700
commit     f7e26d788757f917b32749856bb29feb7b4c2987 (patch)
tree       b3efe1ff2f73e345c1e4c955239c648e0b9739ca /sql
parent     bc2767df2666ff615e7f44e980555afab06dd8a3 (diff)
[SPARK-16922] [SPARK-17211] [SQL] make the address of values portable in LongToUnsafeRowMap
## What changes were proposed in this pull request?

In LongToUnsafeRowMap, we use the offset of a value as its pointer, stored in the index array and also inside the page (to chain multiple values for the same key). That offset is not portable, because Platform.LONG_ARRAY_OFFSET differs across JVM heap sizes, so a deserialized LongToUnsafeRowMap could be corrupted. This PR changes the map to use a portable address instead (one that does not include Platform.LONG_ARRAY_OFFSET).

## How was this patch tested?

Added a test case with randomly generated keys to improve coverage. It is not a regression test, though: reproducing the original bug would require a Spark cluster whose driver or executor has at least 32 GB of heap.

Author: Davies Liu <davies@databricks.com>

Closes #14927 from davies/longmap.
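To make the encoding concrete, here is a minimal sketch (not part of the patch; the SIZE_BITS/SIZE_MASK values and object name are assumptions for illustration): the page-relative offset lives in the high bits of a Long and the row size in the low bits, so the packed value no longer depends on the JVM-specific Platform.LONG_ARRAY_OFFSET and can be read back on a JVM whose array base offset differs.

// Minimal sketch of the packed-address idea; constants are assumed values.
object PortableAddressSketch {
  private val SIZE_BITS = 28
  private val SIZE_MASK = (1L << SIZE_BITS) - 1

  // Pack an offset that is already relative to the start of the page, so the
  // encoded value contains no JVM-specific Platform.LONG_ARRAY_OFFSET.
  def toAddress(relativeOffset: Long, size: Int): Long =
    (relativeOffset << SIZE_BITS) | size

  def toOffset(address: Long): Long = address >>> SIZE_BITS

  def toSize(address: Long): Int = (address & SIZE_MASK).toInt

  def main(args: Array[String]): Unit = {
    val addr = toAddress(1024L, 56)
    assert(toOffset(addr) == 1024L && toSize(addr) == 56)
    println(s"offset=${toOffset(addr)} size=${toSize(addr)}")
  }
}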
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala       | 27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala  | 56
2 files changed, 75 insertions, 8 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 08975733ff..8821c0dea9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -447,10 +447,20 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
private def nextSlot(pos: Int): Int = (pos + 2) & mask
+ private[this] def toAddress(offset: Long, size: Int): Long = {
+ ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size
+ }
+
+ private[this] def toOffset(address: Long): Long = {
+ (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET
+ }
+
+ private[this] def toSize(address: Long): Int = {
+ (address & SIZE_MASK).toInt
+ }
+
private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
- val offset = address >>> SIZE_BITS
- val size = address & SIZE_MASK
- resultRow.pointTo(page, offset, size.toInt)
+ resultRow.pointTo(page, toOffset(address), toSize(address))
resultRow
}
@@ -485,9 +495,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
var addr = address
override def hasNext: Boolean = addr != 0
override def next(): UnsafeRow = {
- val offset = addr >>> SIZE_BITS
- val size = addr & SIZE_MASK
- resultRow.pointTo(page, offset, size.toInt)
+ val offset = toOffset(addr)
+ val size = toSize(addr)
+ resultRow.pointTo(page, offset, size)
addr = Platform.getLong(page, offset + size)
resultRow
}
@@ -554,7 +564,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
Platform.putLong(page, cursor, 0)
cursor += 8
numValues += 1
- updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
+ updateIndex(key, toAddress(offset, row.getSizeInBytes))
}
/**
@@ -562,6 +572,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
private def updateIndex(key: Long, address: Long): Unit = {
var pos = firstSlot(key)
+ assert(numKeys < array.length / 2)
while (array(pos) != key && array(pos + 1) != 0) {
pos = nextSlot(pos)
}
@@ -582,7 +593,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
}
} else {
// there are some values for this key, put the address in the front of them.
- val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
+ val pointer = toOffset(address) + toSize(address)
Platform.putLong(page, pointer, array(pos + 1))
array(pos + 1) = address
}
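The updateIndex change above keeps the chain of values for a key linked through these portable addresses: the 8-byte slot written right after each row holds the packed address of the older value, and the iterator in the earlier hunk follows those slots. Below is a simplified, standalone sketch of that layout (assumed names; each "row" is a single Long and a plain Map stands in for the index array, whereas the real map stores UnsafeRow bytes via Platform), not the actual Spark implementation.

// Simplified sketch of the chained-values layout: each entry in the page is
// <row bytes><8-byte next pointer>; the pointer holds the packed address of
// the previous head for the same key, and 0 marks the end of the chain.
object ChainedValuesSketch {
  private val SIZE_BITS = 28
  private val SIZE_MASK = (1L << SIZE_BITS) - 1

  private def toAddress(offset: Long, size: Int): Long = (offset << SIZE_BITS) | size
  private def toOffset(address: Long): Long = address >>> SIZE_BITS
  private def toSize(address: Long): Int = (address & SIZE_MASK).toInt

  private val page = new Array[Byte](1 << 16)
  private val buf = java.nio.ByteBuffer.wrap(page)
  private var cursor = 0L
  // head address per key (0 = no values), playing the role of array(pos + 1)
  private val heads = scala.collection.mutable.Map.empty[Long, Long].withDefaultValue(0L)

  def append(key: Long, value: Long): Unit = {
    val offset = cursor
    buf.putLong(offset.toInt, value)              // the "row" (8 bytes here)
    cursor += 8
    buf.putLong(cursor.toInt, 0L)                 // next-pointer slot, initially 0
    cursor += 8
    val address = toAddress(offset, 8)
    val previousHead = heads(key)
    if (previousHead != 0L) {
      // the new value becomes the head; its pointer slot links to the old head
      buf.putLong((toOffset(address) + toSize(address)).toInt, previousHead)
    }
    heads(key) = address
  }

  def get(key: Long): Iterator[Long] = new Iterator[Long] {
    private var addr = heads(key)
    override def hasNext: Boolean = addr != 0L
    override def next(): Long = {
      val offset = toOffset(addr)
      val size = toSize(addr)
      val value = buf.getLong(offset.toInt)
      addr = buf.getLong((offset + size).toInt)   // follow the chain pointer
      value
    }
  }

  def main(args: Array[String]): Unit = {
    append(42L, 1L); append(42L, 2L); append(7L, 3L)
    assert(get(42L).toList == List(2L, 1L))       // newest value first
    assert(get(7L).toList == List(3L))
    println("chain traversal ok")
  }
}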
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 1196f5ec7b..ede63fea96 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
+import scala.util.Random
+
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.serializer.KryoSerializer
@@ -197,6 +199,60 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
}
}
+ test("LongToUnsafeRowMap with random keys") {
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
+
+ val N = 1000000
+ val rand = new Random
+ val keys = (0 to N).map(x => rand.nextLong()).toArray
+
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+ keys.foreach { k =>
+ map.append(k, unsafeProj(InternalRow(k)))
+ }
+ map.optimize()
+
+ val os = new ByteArrayOutputStream()
+ val out = new ObjectOutputStream(os)
+ map.writeExternal(out)
+ out.flush()
+ val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+ val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ map2.readExternal(in)
+
+ val row = unsafeProj(InternalRow(0L)).copy()
+ keys.foreach { k =>
+ val r = map2.get(k, row)
+ assert(r.hasNext)
+ var c = 0
+ while (r.hasNext) {
+ val rr = r.next()
+ assert(rr.getLong(0) === k)
+ c += 1
+ }
+ }
+ var i = 0
+ while (i < N * 10) {
+ val k = rand.nextLong()
+ val r = map2.get(k, row)
+ if (r != null) {
+ assert(r.hasNext)
+ while (r.hasNext) {
+ assert(r.next().getLong(0) === k)
+ }
+ }
+ i += 1
+ }
+ map.free()
+ }
+
test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()