Diffstat (limited to 'core/src')
5 files changed, 154 insertions, 42 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 77a594a3e4..1b2e5417e7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -169,10 +169,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
     }
     val mergeCombiners: (CoGroupCombiner, CoGroupCombiner) => CoGroupCombiner =
       (combiner1, combiner2) => {
-        combiner1.zipAll(combiner2, new CoGroup, new CoGroup).map {
-          case (v1, v2) => v1 ++ v2
+        combiner1.zip(combiner2).map { case (v1, v2) => v1 ++ v2 }
       }
-      }
     new ExternalAppendOnlyMap[K, CoGroupValue, CoGroupCombiner](
       createCombiner, mergeValue, mergeCombiners)
   }
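Both combiner sequences passed to mergeCombiners are built by createCombiner, which allocates one CoGroup per parent RDD, so the two sequences always have equal length; that invariant is what makes the simpler zip safe where zipAll previously padded with empty CoGroups. A minimal standalone sketch of the idea (numRdds and freshCombiner are hypothetical stand-ins, not names from this patch):

    import scala.collection.mutable.ArrayBuffer

    object ZipMergeSketch {
      type CoGroup = ArrayBuffer[Any]

      def main(args: Array[String]): Unit = {
        val numRdds = 2 // stand-in for the number of parent RDDs
        // Every combiner has exactly one CoGroup per parent RDD...
        def freshCombiner(): Seq[CoGroup] = Seq.fill(numRdds)(new CoGroup)
        val c1 = freshCombiner(); c1(0) += "a"
        val c2 = freshCombiner(); c2(1) += "b"
        // ...so zip (which truncates to the shorter sequence) never drops
        // anything, and the zipAll padding was dead weight.
        val merged = c1.zip(c2).map { case (v1, v2) => v1 ++ v2 }
        assert(merged.map(_.toList) == Seq(List("a"), List("b")))
      }
    }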
diff --git a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
index a32416afae..d2a9574a71 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/AppendOnlyMap.scala
@@ -48,10 +48,14 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
   private var haveNullValue = false
   private var nullValue: V = null.asInstanceOf[V]
 
+  // Triggered by destructiveSortedIterator; the underlying data array may no longer be used
+  private var destroyed = false
+
   private val LOAD_FACTOR = 0.7
 
   /** Get the value for a given key */
   def apply(key: K): V = {
+    checkValidityOrThrowException()
     val k = key.asInstanceOf[AnyRef]
     if (k.eq(null)) {
       return nullValue
@@ -75,6 +79,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
 
   /** Set the value for a key */
   def update(key: K, value: V): Unit = {
+    checkValidityOrThrowException()
     val k = key.asInstanceOf[AnyRef]
     if (k.eq(null)) {
       if (!haveNullValue) {
@@ -109,6 +114,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
    * for key, if any, or null otherwise. Returns the newly updated value.
    */
   def changeValue(key: K, updateFunc: (Boolean, V) => V): V = {
+    checkValidityOrThrowException()
     val k = key.asInstanceOf[AnyRef]
     if (k.eq(null)) {
       if (!haveNullValue) {
@@ -142,35 +148,38 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
   }
 
   /** Iterator method from Iterable */
-  override def iterator: Iterator[(K, V)] = new Iterator[(K, V)] {
-    var pos = -1
+  override def iterator: Iterator[(K, V)] = {
+    checkValidityOrThrowException()
+    new Iterator[(K, V)] {
+      var pos = -1
 
-    /** Get the next value we should return from next(), or null if we're finished iterating */
-    def nextValue(): (K, V) = {
-      if (pos == -1) { // Treat position -1 as looking at the null value
-        if (haveNullValue) {
-          return (null.asInstanceOf[K], nullValue)
+      /** Get the next value we should return from next(), or null if we're finished iterating */
+      def nextValue(): (K, V) = {
+        if (pos == -1) { // Treat position -1 as looking at the null value
+          if (haveNullValue) {
+            return (null.asInstanceOf[K], nullValue)
+          }
+          pos += 1
        }
-        pos += 1
-      }
-      while (pos < capacity) {
-        if (!data(2 * pos).eq(null)) {
-          return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+        while (pos < capacity) {
+          if (!data(2 * pos).eq(null)) {
+            return (data(2 * pos).asInstanceOf[K], data(2 * pos + 1).asInstanceOf[V])
+          }
+          pos += 1
        }
-        pos += 1
+        null
      }
-      null
-    }
 
-    override def hasNext: Boolean = nextValue() != null
+      override def hasNext: Boolean = nextValue() != null
 
-    override def next(): (K, V) = {
-      val value = nextValue()
-      if (value == null) {
-        throw new NoSuchElementException("End of iterator")
+      override def next(): (K, V) = {
+        val value = nextValue()
+        if (value == null) {
+          throw new NoSuchElementException("End of iterator")
+        }
+        pos += 1
+        value
      }
-      pos += 1
-      value
     }
   }
@@ -238,12 +247,14 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
     if (highBit == n) n else highBit << 1
   }
 
-  /** Return an iterator of the map in sorted order. This provides a way to sort the map without
-   * using additional memory, at the expense of destroying the validity of the map.
-   */
+  /**
+   * Return an iterator of the map in sorted order. This provides a way to sort the map without
+   * using additional memory, at the expense of destroying the validity of the map.
+   */
   def destructiveSortedIterator(cmp: Comparator[(K, V)]): Iterator[(K, V)] = {
-    var keyIndex, newIndex = 0
+    destroyed = true
     // Pack KV pairs into the front of the underlying array
+    var keyIndex, newIndex = 0
     while (keyIndex < capacity) {
       if (data(2 * keyIndex) != null) {
         data(newIndex) = (data(2 * keyIndex), data(2 * keyIndex + 1))
@@ -251,23 +262,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
       }
       keyIndex += 1
     }
-    assert(newIndex == curSize)
+    assert(curSize == newIndex + (if (haveNullValue) 1 else 0))
+
     // Sort by the given ordering
     val rawOrdering = new Comparator[AnyRef] {
       def compare(x: AnyRef, y: AnyRef): Int = {
         cmp.compare(x.asInstanceOf[(K, V)], y.asInstanceOf[(K, V)])
       }
     }
-    util.Arrays.sort(data, 0, curSize, rawOrdering)
+    util.Arrays.sort(data, 0, newIndex, rawOrdering)
 
     new Iterator[(K, V)] {
       var i = 0
-      def hasNext = i < curSize
+      var nullValueReady = haveNullValue
+      def hasNext: Boolean = (i < newIndex || nullValueReady)
       def next(): (K, V) = {
-        val item = data(i).asInstanceOf[(K, V)]
-        i += 1
-        item
+        if (nullValueReady) {
+          nullValueReady = false
+          (null.asInstanceOf[K], nullValue)
+        } else {
+          val item = data(i).asInstanceOf[(K, V)]
+          i += 1
+          item
+        }
       }
     }
   }
+
+  private def checkValidityOrThrowException(): Unit = {
+    if (destroyed) {
+      throw new IllegalStateException("Map state is invalid from destructive sorting!")
+    }
+  }
 }
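The net effect of the AppendOnlyMap changes: destructiveSortedIterator now marks the map destroyed so apply, update, changeValue and iterator fail fast; it sorts only the newIndex packed entries rather than curSize (which also counts the separately stored null key); and the returned iterator emits the null-key entry first, before the sorted run, so the caller's comparator only ever sees non-null keys. A small usage sketch of that contract (assumes AppendOnlyMap is visible on the classpath, as in the Spark tree at this commit):

    import java.util.Comparator
    import org.apache.spark.util.collection.AppendOnlyMap

    object DestructiveSortSketch {
      def main(args: Array[String]): Unit = {
        val map = new AppendOnlyMap[String, Int]()
        map("b") = 2
        map("a") = 1
        map.update(null, 0) // the null key is tracked outside the data array

        val cmp = new Comparator[(String, Int)] {
          def compare(x: (String, Int), y: (String, Int)): Int = x._1.compareTo(y._1)
        }
        // Sorts in place by reusing the map's own array: no extra memory, but
        // the map is unusable afterwards. The null-key entry comes out first.
        val it = map.destructiveSortedIterator(cmp)
        assert(it.toList.map(_._2) == List(0, 1, 2))
        // Any further access now throws instead of reading a corrupted array.
        try { map("a"); assert(false) }
        catch { case _: IllegalStateException => () }
      }
    }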
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 0e8f46cfc7..492b4fc7c6 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -23,7 +23,7 @@ import java.util.Comparator
 import scala.collection.mutable.{ArrayBuffer, PriorityQueue}
 import scala.reflect.ClassTag
 
-import org.apache.spark.SparkEnv
+import org.apache.spark.{Logging, SparkEnv}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.{DiskBlockManager, DiskBlockObjectWriter}
 
@@ -106,14 +106,15 @@ private[spark] class SpillableAppendOnlyMap[K, V, G: ClassTag, C: ClassTag](
     createCombiner: G => C,
     serializer: Serializer,
     diskBlockManager: DiskBlockManager)
-  extends Iterable[(K, C)] with Serializable {
+  extends Iterable[(K, C)] with Serializable with Logging {
 
   import SpillableAppendOnlyMap._
 
   private var currentMap = new SizeTrackingAppendOnlyMap[K, G]
   private val oldMaps = new ArrayBuffer[DiskKGIterator]
-  private val memoryThreshold = {
-    val bufferSize = System.getProperty("spark.shuffle.buffer.mb", "1024").toLong * 1024 * 1024
+
+  private val memoryThresholdMB = {
+    val bufferSize = System.getProperty("spark.shuffle.buffer.mb", "1024").toLong
     val bufferPercent = System.getProperty("spark.shuffle.buffer.fraction", "0.8").toFloat
     bufferSize * bufferPercent
   }
@@ -121,18 +122,22 @@ private[spark] class SpillableAppendOnlyMap[K, V, G: ClassTag, C: ClassTag](
     System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
   private val comparator = new KeyGroupComparator[K, G]
   private val ser = serializer.newInstance()
+  private var spillCount = 0
 
   def insert(key: K, value: V): Unit = {
     val update: (Boolean, G) => G = (hadVal, oldVal) => {
       if (hadVal) mergeValue(oldVal, value) else createGroup(value)
     }
     currentMap.changeValue(key, update)
-    if (currentMap.estimateSize() > memoryThreshold) {
+    if (currentMap.estimateSize() > memoryThresholdMB * 1024 * 1024) {
       spill()
     }
   }
 
   private def spill(): Unit = {
+    spillCount += 1
+    logWarning(s"In-memory KV map exceeded threshold of $memoryThresholdMB MB!")
+    logWarning(s"Spilling to disk ($spillCount time"+(if (spillCount > 1) "s" else "")+" so far)")
     val (blockId, file) = diskBlockManager.createIntermediateBlock
     val writer = new DiskBlockObjectWriter(blockId, file, serializer, fileBufferSize, identity)
     try {
@@ -257,14 +262,15 @@ private[spark] class SpillableAppendOnlyMap[K, V, G: ClassTag, C: ClassTag](
 
   // Iterate through (K, G) pairs in sorted order from an on-disk map
   private class DiskKGIterator(file: File) extends Iterator[(K, G)] {
-    val in = ser.deserializeStream(new FileInputStream(file))
+    val fstream = new FileInputStream(file)
+    val dstream = ser.deserializeStream(fstream)
     var nextItem: Option[(K, G)] = None
     var eof = false
 
     def readNextItem(): Option[(K, G)] = {
       if (!eof) {
         try {
-          return Some(in.readObject().asInstanceOf[(K, G)])
+          return Some(dstream.readObject().asInstanceOf[(K, G)])
         } catch {
           case e: EOFException =>
             eof = true
@@ -296,6 +302,8 @@ private[spark] class SpillableAppendOnlyMap[K, V, G: ClassTag, C: ClassTag](
 
     // TODO: Ensure this gets called even if the iterator isn't drained.
     def cleanup() {
+      fstream.close()
+      dstream.close()
       file.delete()
     }
   }
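The spill threshold is numerically unchanged: it is now stored in MB (memoryThresholdMB) and converted to bytes at the comparison site, so the new spill warnings can report it directly. A sketch of the arithmetic using the same system properties and defaults the patch reads:

    object SpillThresholdSketch {
      def main(args: Array[String]): Unit = {
        // The two properties read by SpillableAppendOnlyMap, with its defaults:
        val bufferSizeMB = System.getProperty("spark.shuffle.buffer.mb", "1024").toLong
        val bufferFraction = System.getProperty("spark.shuffle.buffer.fraction", "0.8").toFloat
        val memoryThresholdMB = bufferSizeMB * bufferFraction
        // insert() triggers spill() once estimateSize() crosses this in bytes
        val thresholdBytes = memoryThresholdMB * 1024 * 1024
        println(f"spill above $thresholdBytes%.0f bytes ($memoryThresholdMB%.1f MB by default)")
      }
    }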
diff --git a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
index 7e7aa7800d..71b936b0df 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/AppendOnlyMapSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.util.collection
 import scala.collection.mutable.HashSet
 
 import org.scalatest.FunSuite
+import java.util.Comparator
 
 class AppendOnlyMapSuite extends FunSuite {
   test("initialization") {
@@ -151,4 +152,47 @@ class AppendOnlyMapSuite extends FunSuite {
       assert(map("" + i) === "" + i)
     }
   }
+
+  test("destructive sort") {
+    val map = new AppendOnlyMap[String, String]()
+    for (i <- 1 to 100) {
+      map("" + i) = "" + i
+    }
+    map.update(null, "happy new year!")
+
+    try {
+      map.apply("1")
+      map.update("1", "2013")
+      map.changeValue("1", (hadValue, oldValue) => "2014")
+      map.iterator
+    } catch {
+      case e: IllegalStateException => fail()
+    }
+
+    val it = map.destructiveSortedIterator(new Comparator[(String, String)] {
+      def compare(kv1: (String, String), kv2: (String, String)): Int = {
+        val x = if (kv1 != null && kv1._1 != null) kv1._1.toInt else Int.MinValue
+        val y = if (kv2 != null && kv2._1 != null) kv2._1.toInt else Int.MinValue
+        x.compareTo(y)
+      }
+    })
+
+    // Should be sorted by key
+    assert(it.hasNext)
+    var previous = it.next()
+    assert(previous == (null, "happy new year!"))
+    previous = it.next()
+    assert(previous == ("1", "2014"))
+    while (it.hasNext) {
+      val kv = it.next()
+      assert(kv._1.toInt > previous._1.toInt)
+      previous = kv
+    }
+
+    // All subsequent calls to apply, update, changeValue and iterator should throw exception
+    intercept[IllegalStateException] { map.apply("1") }
+    intercept[IllegalStateException] { map.update("1", "2013") }
+    intercept[IllegalStateException] { map.changeValue("1", (hadValue, oldValue) => "2014") }
+    intercept[IllegalStateException] { map.iterator }
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
index 3bc88caaf3..baf94b4728 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala
@@ -113,6 +113,44 @@ class ExternalAppendOnlyMapSuite extends FunSuite with BeforeAndAfter with Local
     assert(kv1._2 == kv2._2 && kv2._2 == kv3._2)
   }
 
+  test("null keys and values") {
+    val map = new ExternalAppendOnlyMap[Int, Int, ArrayBuffer[Int]](createCombiner,
+      mergeValue, mergeCombiners)
+    map.insert(1, 5)
+    map.insert(2, 6)
+    map.insert(3, 7)
+    assert(map.size === 3)
+    assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+      (1, Seq[Int](5)),
+      (2, Seq[Int](6)),
+      (3, Seq[Int](7))
+    ))
+
+    // Null keys
+    val nullInt = null.asInstanceOf[Int]
+    map.insert(nullInt, 8)
+    assert(map.size === 4)
+    assert(map.iterator.toSet == Set[(Int, Seq[Int])](
+      (1, Seq[Int](5)),
+      (2, Seq[Int](6)),
+      (3, Seq[Int](7)),
+      (nullInt, Seq[Int](8))
+    ))
+
+    // Null values
+    map.insert(4, nullInt)
+    map.insert(nullInt, nullInt)
+    assert(map.size === 5)
+    val result = map.iterator.toSet[(Int, ArrayBuffer[Int])].map(kv => (kv._1, kv._2.toSet))
+    assert(result == Set[(Int, Set[Int])](
+      (1, Set[Int](5)),
+      (2, Set[Int](6)),
+      (3, Set[Int](7)),
+      (4, Set[Int](nullInt)),
+      (nullInt, Set[Int](nullInt, 8))
+    ))
+  }
+
   test("simple aggregator") {
     // reduceByKey
     val rdd = sc.parallelize(1 to 10).map(i => (i%2, 1))
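The new test uses createCombiner, mergeValue and mergeCombiners, helpers defined earlier in the suite and not shown in this hunk; presumably the usual ArrayBuffer-based trio, roughly like the stand-ins below. Also note that for a primitive key type, null.asInstanceOf[Int] unboxes to 0, so the "null key" this test inserts is effectively the key 0; genuinely null keys only arise for reference (AnyRef) key types.

    import scala.collection.mutable.ArrayBuffer

    // Hypothetical stand-ins for the suite's combiner helpers (not shown in the hunk)
    object CombinerSketch {
      def createCombiner(v: Int): ArrayBuffer[Int] = ArrayBuffer(v)
      def mergeValue(buf: ArrayBuffer[Int], v: Int): ArrayBuffer[Int] = buf += v
      def mergeCombiners(b1: ArrayBuffer[Int], b2: ArrayBuffer[Int]): ArrayBuffer[Int] = b1 ++= b2

      def main(args: Array[String]): Unit = {
        assert(null.asInstanceOf[Int] == 0) // why "nullInt" behaves like key 0
        assert(mergeValue(createCombiner(1), 2) == ArrayBuffer(1, 2))
        assert(mergeCombiners(createCombiner(8), createCombiner(0)) == ArrayBuffer(8, 0))
      }
    }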