author     Davies Liu <davies@databricks.com>   2016-08-04 11:20:17 -0700
committer  Davies Liu <davies.liu@gmail.com>    2016-08-04 11:20:17 -0700
commit     9d4e6212fa8d434089d32bff1217f39919abe44d (patch)
tree       9dbfdd0832bd98d7b497462acf9c5a916d5bbab7
parent     9d7a47406ed538f0005cdc7a62bc6e6f20634815 (diff)
[SPARK-16802] [SQL] fix overflow in LongToUnsafeRowMap
## What changes were proposed in this pull request?

This patch fixes the overflow in LongToUnsafeRowMap when the range of keys is very wide (the probe key is far smaller than minKey; for example, the key is Long.MinValue while minKey is > 0).

## How was this patch tested?

Added a regression test (also covering SPARK-16740).

Author: Davies Liu <davies@databricks.com>

Closes #14464 from davies/fix_overflow.
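To see why the old dense-path lookup was unsound, here is a minimal sketch (not part of the patch; the object name is illustrative) that reproduces the overflow with the key values used in the SPARK-16802 regression test below:

```scala
object OverflowSketch {
  def main(args: Array[String]): Unit = {
    // Key range taken from the SPARK-16802 regression test below.
    val minKey = Long.MaxValue - 10 // smallest key stored in the dense map
    val maxKey = Long.MaxValue      // largest key stored in the dense map
    val key    = Long.MinValue      // probe key far below minKey

    // The old code computed the slot index before any lower-bound check.
    // Long subtraction wraps around: Long.MinValue - (Long.MaxValue - 10) == 11.
    val idx = (key - minKey).toInt  // idx == 11

    // The old guard passes even though key < minKey, so the lookup would
    // read a slot that the probe key never occupied.
    println(s"old guard: ${idx >= 0 && key <= maxKey}") // true

    // The patched guard checks the key range first, so the wrapped index
    // is never used for a key outside [minKey, maxKey].
    println(s"new guard: ${key >= minKey && key <= maxKey}") // false
  }
}
```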
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala       | 16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala  | 45
2 files changed, 55 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index cf4454c033..08975733ff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -459,9 +459,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def getValue(key: Long, resultRow: UnsafeRow): UnsafeRow = {
if (isDense) {
- val idx = (key - minKey).toInt
- if (idx >= 0 && key <= maxKey && array(idx) > 0) {
- return getRow(array(idx), resultRow)
+ if (key >= minKey && key <= maxKey) {
+ val value = array((key - minKey).toInt)
+ if (value > 0) {
+ return getRow(value, resultRow)
+ }
}
} else {
var pos = firstSlot(key)
@@ -497,9 +499,11 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
*/
def get(key: Long, resultRow: UnsafeRow): Iterator[UnsafeRow] = {
if (isDense) {
- val idx = (key - minKey).toInt
- if (idx >= 0 && key <= maxKey && array(idx) > 0) {
- return valueIter(array(idx), resultRow)
+ if (key >= minKey && key <= maxKey) {
+ val value = array((key - minKey).toInt)
+ if (value > 0) {
+ return valueIter(value, resultRow)
+ }
}
} else {
var pos = firstSlot(key)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 40864c80eb..1196f5ec7b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -152,6 +152,51 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
}
}
+ test("LongToUnsafeRowMap with very wide range") {
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
+
+ {
+ // SPARK-16740
+ val keys = Seq(0L, Long.MaxValue, Long.MaxValue)
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ keys.foreach { k =>
+ map.append(k, unsafeProj(InternalRow(k)))
+ }
+ map.optimize()
+ val row = unsafeProj(InternalRow(0L)).copy()
+ keys.foreach { k =>
+ assert(map.getValue(k, row) eq row)
+ assert(row.getLong(0) === k)
+ }
+ map.free()
+ }
+
+
+ {
+ // SPARK-16802
+ val keys = Seq(Long.MaxValue, Long.MaxValue - 10)
+ val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+ keys.foreach { k =>
+ map.append(k, unsafeProj(InternalRow(k)))
+ }
+ map.optimize()
+ val row = unsafeProj(InternalRow(0L)).copy()
+ keys.foreach { k =>
+ assert(map.getValue(k, row) eq row)
+ assert(row.getLong(0) === k)
+ }
+ assert(map.getValue(Long.MinValue, row) eq null)
+ map.free()
+ }
+ }
+
test("Spark-14521") {
val ser = new KryoSerializer(
(new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()