From 9d4e6212fa8d434089d32bff1217f39919abe44d Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 4 Aug 2016 11:20:17 -0700 Subject: [SPARK-16802] [SQL] fix overflow in LongToUnsafeRowMap ## What changes were proposed in this pull request? This patch fix the overflow in LongToUnsafeRowMap when the range of key is very wide (the key is much much smaller then minKey, for example, key is Long.MinValue, minKey is > 0). ## How was this patch tested? Added regression test (also for SPARK-16740) Author: Davies Liu Closes #14464 from davies/fix_overflow. --- .../spark/sql/execution/joins/HashedRelation.scala | 16 +++++--- .../sql/execution/joins/HashedRelationSuite.scala | 45 ++++++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cf4454c033..08975733ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return getRow(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return valueIter(value, resultRow) + } } } else { var pos = firstSlot(key) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 40864c80eb..1196f5ec7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { } } + test("LongToUnsafeRowMap with very wide range") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + { + // SPARK-16740 + val keys = Seq(0L, Long.MaxValue, Long.MaxValue) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + map.free() + } + + + { + // SPARK-16802 + val keys = Seq(Long.MaxValue, Long.MaxValue - 10) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + assert(map.getValue(Long.MinValue, row) eq null) + map.free() + } + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() -- cgit v1.2.3