diff options
author | Aaron Davidson <aaron@databricks.com> | 2013-12-26 01:28:18 -0800 |
---|---|---|
committer | Andrew Or <andrewor14@gmail.com> | 2013-12-26 23:40:07 -0800 |
commit | 804beb43bebe50e88814c0ca702a51571cd044e7 (patch) | |
tree | c4c64f2c7972effda10c967e05a6697fa0b78c60 | |
parent | 7ad4408255e37f95e545d9c21a4460cbf98c05dd (diff) | |
download | spark-804beb43bebe50e88814c0ca702a51571cd044e7.tar.gz spark-804beb43bebe50e88814c0ca702a51571cd044e7.tar.bz2 spark-804beb43bebe50e88814c0ca702a51571cd044e7.zip |
SamplingSizeTracker + Map + test suite
5 files changed, 204 insertions, 11 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index 8bb4ee3bfa..899cd6ac14 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -190,7 +190,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** Double the table's size and re-hash everything */ - private def growTable() { + protected def growTable() { val newCapacity = capacity * 2 if (newCapacity >= (1 << 30)) { // We can't make the table this big because we want an array of 2x diff --git a/core/src/main/scala/org/apache/spark/util/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/ExternalAppendOnlyMap.scala index 413f83862d..b97b28282a 100644 --- a/core/src/main/scala/org/apache/spark/util/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/ExternalAppendOnlyMap.scala @@ -18,12 +18,10 @@ package org.apache.spark.util import java.io._ -import java.text.DecimalFormat -import scala.Some -import scala.Predef._ import scala.collection.mutable.{ArrayBuffer, PriorityQueue} -import scala.util.Random + +import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap /** * A wrapper for SpillableAppendOnlyMap that handles two cases: @@ -88,8 +86,7 @@ class SpillableAppendOnlyMap[K, V, M, C]( memoryThresholdMB: Long = 1024) extends Iterable[(K, C)] with Serializable { - var currentMap = new AppendOnlyMap[K, M] - var sizeTracker = new SamplingSizeTracker(currentMap) + var currentMap = new SizeTrackingAppendOnlyMap[K, M] var oldMaps = new ArrayBuffer[DiskIterator] def insert(key: K, value: V): Unit = { @@ -97,8 +94,8 @@ class SpillableAppendOnlyMap[K, V, M, C]( if (hadVal) mergeValue(oldVal, value) else createGroup(value) } currentMap.changeValue(key, update) - sizeTracker.updateMade() - if (sizeTracker.estimateSize() > memoryThresholdMB * 1024 * 1024) { + // TODO: Make sure we're only using some % of the actual threshold due to error + if (currentMap.estimateSize() > memoryThresholdMB * 1024 * 1024) { spill() } } @@ -109,8 +106,7 @@ class SpillableAppendOnlyMap[K, V, M, C]( val sortedMap = currentMap.iterator.toList.sortBy(kv => kv._1.hashCode()) sortedMap.foreach(out.writeObject) out.close() - currentMap = new AppendOnlyMap[K, M] - sizeTracker = new SamplingSizeTracker(currentMap) + currentMap = new SizeTrackingAppendOnlyMap[K, M] oldMaps.append(new DiskIterator(file)) } diff --git a/core/src/main/scala/org/apache/spark/util/SamplingSizeTracker.scala b/core/src/main/scala/org/apache/spark/util/SamplingSizeTracker.scala new file mode 100644 index 0000000000..2262b7d1be --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/SamplingSizeTracker.scala @@ -0,0 +1,67 @@ +package org.apache.spark.util + +import org.apache.spark.util.SamplingSizeTracker.Sample + +/** + * Estimates the size of an object as it grows, in bytes. + * We sample with a slow exponential back-off using the SizeEstimator to amortize the time, + * as each call to SizeEstimator can take a sizable amount of time (order of a few milliseconds). + * + * Users should call updateMade() every time their object is updated with new data, or + * flushSamples() if there is a non-linear change in object size (otherwise linear is assumed). + * Not threadsafe. + */ +class SamplingSizeTracker(obj: AnyRef) { + /** + * Controls the base of the exponential which governs the rate of sampling. + * E.g., a value of 2 would mean we sample at 1, 2, 4, 8, ... elements. + */ + private val SAMPLE_GROWTH_RATE = 1.1 + + private var lastLastSample: Sample = _ + private var lastSample: Sample = _ + + private var numUpdates: Long = _ + private var nextSampleNum: Long = _ + + flushSamples() + + /** Called after a non-linear change in the tracked object. Takes a new sample. */ + def flushSamples() { + numUpdates = 0 + nextSampleNum = 1 + // Throw out both prior samples to avoid overestimating delta. + lastSample = Sample(SizeEstimator.estimate(obj), 0) + lastLastSample = lastSample + } + + /** To be called after an update to the tracked object. Amortized O(1) time. */ + def updateMade() { + numUpdates += 1 + if (nextSampleNum == numUpdates) { + lastLastSample = lastSample + lastSample = Sample(SizeEstimator.estimate(obj), numUpdates) + nextSampleNum = math.ceil(numUpdates * SAMPLE_GROWTH_RATE).toLong + } + } + + /** Estimates the current size of the tracked object. O(1) time. */ + def estimateSize(): Long = { + val interpolatedDelta = + if (lastLastSample != null && lastLastSample != lastSample) { + (lastSample.size - lastLastSample.size).toDouble / + (lastSample.numUpdates - lastLastSample.numUpdates) + } else if (lastSample.numUpdates > 0) { + lastSample.size.toDouble / lastSample.numUpdates + } else { + 0 + } + val extrapolatedDelta = interpolatedDelta * (numUpdates - lastSample.numUpdates) + val estimate = lastSample.size + extrapolatedDelta + math.max(0, estimate).toLong + } +} + +object SamplingSizeTracker { + case class Sample(size: Long, numUpdates: Long) +} diff --git a/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala new file mode 100644 index 0000000000..2b2417efd9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/collection/SizeTrackingAppendOnlyMap.scala @@ -0,0 +1,27 @@ +package org.apache.spark.util.collection + +import org.apache.spark.util.{AppendOnlyMap, SamplingSizeTracker} + +/** Append-only map that keeps track of its estimated size in bytes. */ +class SizeTrackingAppendOnlyMap[K, V] extends AppendOnlyMap[K, V] { + + val sizeTracker = new SamplingSizeTracker(this) + + def estimateSize() = sizeTracker.estimateSize() + + override def update(key: K, value: V): Unit = { + super.update(key, value) + sizeTracker.updateMade() + } + + override def changeValue(key: K, updateFunc: (Boolean, V) => V): V = { + val newValue = super.changeValue(key, updateFunc) + sizeTracker.updateMade() + newValue + } + + override protected def growTable() { + super.growTable() + sizeTracker.flushSamples() + } +} diff --git a/core/src/test/scala/org/apache/spark/util/SamplingSizeTrackerSuite.scala b/core/src/test/scala/org/apache/spark/util/SamplingSizeTrackerSuite.scala new file mode 100644 index 0000000000..bd3ff5ff41 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/SamplingSizeTrackerSuite.scala @@ -0,0 +1,103 @@ +package org.apache.spark.util + +import scala.util.Random + +import org.scalatest.{BeforeAndAfterAll, FunSuite} + +import org.apache.spark.util.SamplingSizeTrackerSuite.LargeDummyClass +import org.apache.spark.util.collection.SizeTrackingAppendOnlyMap + +class SamplingSizeTrackerSuite extends FunSuite with BeforeAndAfterAll { + val NORMAL_ERROR = 0.20 + val HIGH_ERROR = 0.30 + + test("fixed size insertions") { + testWith[Int, Long](10000, i => (i, i.toLong)) + testWith[Int, (Long, Long)](10000, i => (i, (i.toLong, i.toLong))) + testWith[Int, LargeDummyClass](10000, i => (i, new LargeDummyClass())) + } + + test("variable size insertions") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[Int, String](10000, i => (i, randString(0, 10))) + testWith[Int, String](10000, i => (i, randString(0, 100))) + testWith[Int, String](10000, i => (i, randString(90, 100))) + } + + test("updates") { + val rand = new Random(123456789) + def randString(minLen: Int, maxLen: Int): String = { + "a" * (rand.nextInt(maxLen - minLen) + minLen) + } + testWith[String, Int](10000, i => (randString(0, 10000), i)) + } + + def testWith[K, V](numElements: Int, makeElement: (Int) => (K, V)) { + val map = new SizeTrackingAppendOnlyMap[K, V]() + for (i <- 0 until numElements) { + val (k, v) = makeElement(i) + map(k) = v + expectWithinError(map, map.estimateSize(), if (i < 32) HIGH_ERROR else NORMAL_ERROR) + } + } + + def expectWithinError(obj: AnyRef, estimatedSize: Long, error: Double) { + val betterEstimatedSize = SizeEstimator.estimate(obj) + assert(betterEstimatedSize * (1 - error) < estimatedSize, + s"Estimated size $estimatedSize was less than expected size $betterEstimatedSize") + assert(betterEstimatedSize * (1 + 2 * error) > estimatedSize, + s"Estimated size $estimatedSize was greater than expected size $betterEstimatedSize") + } +} + +object SamplingSizeTrackerSuite { + // Speed test, for reproducibility of results. + // These could be highly non-deterministic in general, however. + // Results: + // AppendOnlyMap: 30 ms + // SizeTracker: 45 ms + // SizeEstimator: 1500 ms + def main(args: Array[String]) { + val numElements = 100000 + + val baseTimes = for (i <- 0 until 3) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + } + } + + val sampledTimes = for (i <- 0 until 3) yield time { + val map = new SizeTrackingAppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + map.estimateSize() + } + } + + val unsampledTimes = for (i <- 0 until 3) yield time { + val map = new AppendOnlyMap[Int, LargeDummyClass]() + for (i <- 0 until numElements) { + map(i) = new LargeDummyClass() + SizeEstimator.estimate(map) + } + } + + println("Base: " + baseTimes) + println("SizeTracker (sampled): " + sampledTimes) + println("SizeEstimator (unsampled): " + unsampledTimes) + } + + def time(f: => Unit): Long = { + val start = System.currentTimeMillis() + f + System.currentTimeMillis() - start + } + + private class LargeDummyClass { + val arr = new Array[Int](100) + } +}
\ No newline at end of file |