diff options
author | Andrew Or <andrewor14@gmail.com> | 2014-07-03 10:26:50 -0700 |
---|---|---|
committer | Aaron Davidson <aaron@databricks.com> | 2014-07-03 10:26:50 -0700 |
commit | c480537739f9329ebfd580f09c69778e6c976366 (patch) | |
tree | 95526d3961b1aa35adc01cb2e652aff9532e9639 | |
parent | 3bbeca648985b32bdf1eedef779cb2817eb6dfa4 (diff) | |
download | spark-c480537739f9329ebfd580f09c69778e6c976366.tar.gz spark-c480537739f9329ebfd580f09c69778e6c976366.tar.bz2 spark-c480537739f9329ebfd580f09c69778e6c976366.zip |
[SPARK] Fix NPE for ExternalAppendOnlyMap
It did not handle null keys very gracefully before.
Author: Andrew Or <andrewor14@gmail.com>
Closes #1288 from andrewor14/fix-external and squashes the following commits:
312b8d8 [Andrew Or] Abstract key hash code
ed5adf9 [Andrew Or] Fix NPE for ExternalAppendOnlyMap
-rw-r--r-- | core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala | 30 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala | 27 |
2 files changed, 46 insertions, 11 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 288badd316..292d0962f4 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -252,7 +252,7 @@ class ExternalAppendOnlyMap[K, V, C]( if (it.hasNext) { var kc = it.next() kcPairs += kc - val minHash = kc._1.hashCode() + val minHash = getKeyHashCode(kc) while (it.hasNext && it.head._1.hashCode() == minHash) { kc = it.next() kcPairs += kc @@ -294,8 +294,9 @@ class ExternalAppendOnlyMap[K, V, C]( // Select a key from the StreamBuffer that holds the lowest key hash val minBuffer = mergeHeap.dequeue() val (minPairs, minHash) = (minBuffer.pairs, minBuffer.minKeyHash) - var (minKey, minCombiner) = minPairs.remove(0) - assert(minKey.hashCode() == minHash) + val minPair = minPairs.remove(0) + var (minKey, minCombiner) = minPair + assert(getKeyHashCode(minPair) == minHash) // For all other streams that may have this key (i.e. have the same minimum key hash), // merge in the corresponding value (if any) from that stream @@ -327,15 +328,16 @@ class ExternalAppendOnlyMap[K, V, C]( * StreamBuffers are ordered by the minimum key hash found across all of their own pairs. */ private class StreamBuffer( - val iterator: BufferedIterator[(K, C)], val pairs: ArrayBuffer[(K, C)]) + val iterator: BufferedIterator[(K, C)], + val pairs: ArrayBuffer[(K, C)]) extends Comparable[StreamBuffer] { def isEmpty = pairs.length == 0 // Invalid if there are no more pairs in this stream - def minKeyHash = { + def minKeyHash: Int = { assert(pairs.length > 0) - pairs.head._1.hashCode() + getKeyHashCode(pairs.head) } override def compareTo(other: StreamBuffer): Int = { @@ -422,10 +424,22 @@ class ExternalAppendOnlyMap[K, V, C]( } private[spark] object ExternalAppendOnlyMap { + + /** + * Return the key hash code of the given (key, combiner) pair. + * If the key is null, return a special hash code. + */ + private def getKeyHashCode[K, C](kc: (K, C)): Int = { + if (kc._1 == null) 0 else kc._1.hashCode() + } + + /** + * A comparator for (key, combiner) pairs based on their key hash codes. + */ private class KCComparator[K, C] extends Comparator[(K, C)] { def compare(kc1: (K, C), kc2: (K, C)): Int = { - val hash1 = kc1._1.hashCode() - val hash2 = kc2._1.hashCode() + val hash1 = getKeyHashCode(kc1) + val hash2 = getKeyHashCode(kc2) if (hash1 < hash2) -1 else if (hash1 == hash2) 0 else 1 } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index deb7809535..428822949c 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -334,8 +334,8 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { conf.set("spark.shuffle.memoryFraction", "0.001") sc = new SparkContext("local-cluster[1,1,512]", "test", conf) - val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner, - mergeValue, mergeCombiners) + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]]( + createCombiner, mergeValue, mergeCombiners) (1 to 100000).foreach { i => map.insert(i, i) } map.insert(Int.MaxValue, Int.MaxValue) @@ -346,11 +346,32 @@ class ExternalAppendOnlyMapSuite extends FunSuite with LocalSparkContext { it.next() } } + + test("spilling with null keys and values") { + val conf = new SparkConf(true) + conf.set("spark.shuffle.memoryFraction", "0.001") + sc = new SparkContext("local-cluster[1,1,512]", "test", conf) + + val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]]( + createCombiner, mergeValue, mergeCombiners) + + (1 to 100000).foreach { i => map.insert(i, i) } + map.insert(null.asInstanceOf[Int], 1) + map.insert(1, null.asInstanceOf[Int]) + map.insert(null.asInstanceOf[Int], null.asInstanceOf[Int]) + + val it = map.iterator + while (it.hasNext) { + // Should not throw NullPointerException + it.next() + } + } + } /** * A dummy class that always returns the same hash code, to easily test hash collisions */ -case class FixedHashObject(val v: Int, val h: Int) extends Serializable { +case class FixedHashObject(v: Int, h: Int) extends Serializable { override def hashCode(): Int = h } |