diff options
author | Davies Liu <davies@databricks.com> | 2016-08-04 11:20:17 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-08-04 11:20:17 -0700 |
commit | 9d4e6212fa8d434089d32bff1217f39919abe44d (patch) | |
tree | 9dbfdd0832bd98d7b497462acf9c5a916d5bbab7 | |
parent | 9d7a47406ed538f0005cdc7a62bc6e6f20634815 (diff) | |
download | spark-9d4e6212fa8d434089d32bff1217f39919abe44d.tar.gz spark-9d4e6212fa8d434089d32bff1217f39919abe44d.tar.bz2 spark-9d4e6212fa8d434089d32bff1217f39919abe44d.zip |
[SPARK-16802] [SQL] fix overflow in LongToUnsafeRowMap
## What changes were proposed in this pull request?
This patch fix the overflow in LongToUnsafeRowMap when the range of key is very wide (the key is much much smaller then minKey, for example, key is Long.MinValue, minKey is > 0).
## How was this patch tested?
Added regression test (also for SPARK-16740)
Author: Davies Liu <davies@databricks.com>
Closes #14464 from davies/fix_overflow.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 16 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 45 |
2 files changed, 55 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cf4454c033..08975733ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return getRow(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return valueIter(value, resultRow) + } } } else { var pos = firstSlot(key) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 40864c80eb..1196f5ec7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { } } + test("LongToUnsafeRowMap with very wide range") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + { + // SPARK-16740 + val keys = Seq(0L, Long.MaxValue, Long.MaxValue) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + map.free() + } + + + { + // SPARK-16802 + val keys = Seq(Long.MaxValue, Long.MaxValue - 10) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + assert(map.getValue(Long.MinValue, row) eq null) + map.free() + } + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() |