diff options
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 16 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 45 |
2 files changed, 55 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index cf4454c033..08975733ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >= 0 && key <= maxKey && array(idx) > 0) { - return getRow(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return getRow(value, resultRow) + } } } else { var pos = firstSlot(key) @@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap */ def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = { if (isDense) { - val idx = (key - minKey).toInt - if (idx >=0 && key <= maxKey && array(idx) > 0) { - return valueIter(array(idx), resultRow) + if (key >= minKey && key <= maxKey) { + val value = array((key - minKey).toInt) + if (value > 0) { + return valueIter(value, resultRow) + } } } else { var pos = firstSlot(key) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala index 40864c80eb..1196f5ec7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext { } } + test("LongToUnsafeRowMap with very wide range") { + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", "false"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false))) + + { + // SPARK-16740 + val keys = Seq(0L, Long.MaxValue, Long.MaxValue) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + map.free() + } + + + { + // SPARK-16802 + val keys = Seq(Long.MaxValue, Long.MaxValue - 10) + val map = new LongToUnsafeRowMap(taskMemoryManager, 1) + keys.foreach { k => + map.append(k, unsafeProj(InternalRow(k))) + } + map.optimize() + val row = unsafeProj(InternalRow(0L)).copy() + keys.foreach { k => + assert(map.getValue(k, row) eq row) + assert(row.getLong(0) === k) + } + assert(map.getValue(Long.MinValue, row) eq null) + map.free() + } + } + test("Spark-14521") { val ser = new KryoSerializer( (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance() |