 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala      | 16 ++++++++++------
 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala | 45 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cf4454c033..08975733ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
     if (isDense) {
-      val idx = (key - minKey).toInt
-      if (idx >= 0 && key <= maxKey && array(idx) > 0) {
-        return getRow(array(idx), resultRow)
+      if (key >= minKey && key <= maxKey) {
+        val value = array((key - minKey).toInt)
+        if (value > 0) {
+          return getRow(value, resultRow)
+        }
       }
     } else {
       var pos = firstSlot(key)
@@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
     if (isDense) {
-      val idx = (key - minKey).toInt
-      if (idx >=0 && key <= maxKey && array(idx) > 0) {
-        return valueIter(array(idx), resultRow)
+      if (key >= minKey && key <= maxKey) {
+        val value = array((key - minKey).toInt)
+        if (value > 0) {
+          return valueIter(value, resultRow)
+        }
       }
     } else {
       var pos = firstSlot(key)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 40864c80eb..1196f5ec7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
+  test("LongToUnsafeRowMap with very wide range") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
+
+    {
+      // SPARK-16740
+      val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
+      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      keys.foreach { k =>
+        map.append(k, unsafeProj(InternalRow(k)))
+      }
+      map.optimize()
+      val row = unsafeProj(InternalRow(0L)).copy()
+      keys.foreach { k =>
+        assert(map.getValue(k, row) eq row)
+        assert(row.getLong(0) === k)
+      }
+      map.free()
+    }
+
+
+    {
+      // SPARK-16802
+      val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
+      val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+      keys.foreach { k =>
+        map.append(k, unsafeProj(InternalRow(k)))
+      }
+      map.optimize()
+      val row = unsafeProj(InternalRow(0L)).copy()
+      keys.foreach { k =>
+        assert(map.getValue(k, row) eq row)
+        assert(row.getLong(0) === k)
+      }
+      assert(map.getValue(Long.MinValue, row) eq null)
+      map.free()
+    }
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
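
Why the fix compares the untruncated Long keys: key - minKey wraps on 64-bit
overflow, and .toInt can then yield a small non-negative index for a key far
outside the dense range, so the old guard idx >= 0 && key <= maxKey could pass
and read an unrelated slot. A minimal standalone Scala sketch of the probe the
new SPARK-16802 test exercises (illustrative only, not part of the commit):

object DenseIndexOverflowSketch {
  def main(args: Array[String]): Unit = {
    val minKey = Long.MaxValue - 10   // dense range built from the test's keys
    val maxKey = Long.MaxValue
    val key = Long.MinValue           // probe far below the range

    // Long subtraction wraps around: the result is 11, and toInt keeps it
    // small and non-negative, so the old guard would index array(11).
    val idx = (key - minKey).toInt
    println(idx)                             // 11
    println(idx >= 0 && key <= maxKey)       // true  -> old code reads a bogus slot
    println(key >= minKey && key <= maxKey)  // false -> fixed code finds no match
  }
}

Because the patched code rejects out-of-range keys before any narrowing
conversion, a probe like getValue(Long.MinValue, row) can no longer alias a
valid slot and return another key's row, which is exactly what the added
assert(map.getValue(Long.MinValue, row) eq null) checks.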