 core/src/main/scala/org/apache/spark/util/SizeEstimator.scala      | 45 ++++++++++++-----
 core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala | 18 ++++++
 2 files changed, 48 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
index 26ffbf9350..4dd7ab9e07 100644
--- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
+++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala
@@ -179,7 +179,7 @@ private[spark] object SizeEstimator extends Logging {
}
// Estimate the size of arrays larger than ARRAY_SIZE_FOR_SAMPLING by sampling.
- private val ARRAY_SIZE_FOR_SAMPLING = 200
+ private val ARRAY_SIZE_FOR_SAMPLING = 400
private val ARRAY_SAMPLE_SIZE = 100 // should be lower than ARRAY_SIZE_FOR_SAMPLING
private def visitArray(array: AnyRef, arrayClass: Class[_], state: SearchState) {
@@ -204,25 +204,40 @@ private[spark] object SizeEstimator extends Logging {
}
} else {
// Estimate the size of a large array by sampling elements without replacement.
- var size = 0.0
+ // To exclude shared objects that the array elements may link to, sample twice
+ // and use the smaller of the two results to calculate the array size.
val rand = new Random(42)
- val drawn = new OpenHashSet[Int](ARRAY_SAMPLE_SIZE)
- var numElementsDrawn = 0
- while (numElementsDrawn < ARRAY_SAMPLE_SIZE) {
- var index = 0
- do {
- index = rand.nextInt(length)
- } while (drawn.contains(index))
- drawn.add(index)
- val elem = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
- size += SizeEstimator.estimate(elem, state.visited)
- numElementsDrawn += 1
- }
- state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
+ val drawn = new OpenHashSet[Int](2 * ARRAY_SAMPLE_SIZE)
+ val s1 = sampleArray(array, state, rand, drawn, length)
+ val s2 = sampleArray(array, state, rand, drawn, length)
+ val size = math.min(s1, s2)
+ state.size += math.max(s1, s2) +
+ (size * ((length - ARRAY_SAMPLE_SIZE) / (ARRAY_SAMPLE_SIZE))).toLong
}
}
}
+ private def sampleArray(
+ array: AnyRef,
+ state: SearchState,
+ rand: Random,
+ drawn: OpenHashSet[Int],
+ length: Int): Long = {
+ var size = 0L
+ for (i <- 0 until ARRAY_SAMPLE_SIZE) {
+ var index = 0
+ do {
+ index = rand.nextInt(length)
+ } while (drawn.contains(index))
+ drawn.add(index)
+ val obj = ScalaRunTime.array_apply(array, index).asInstanceOf[AnyRef]
+ if (obj != null) {
+ size += SizeEstimator.estimate(obj, state.visited).toLong
+ }
+ }
+ size
+ }
+
private def primitiveSize(cls: Class[_]): Long = {
if (cls == classOf[Byte]) {
BYTE_SIZE
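
For reference, the following is a minimal, self-contained sketch of the extrapolation the new hunk performs once the two samples are drawn. It is an illustration, not the Spark implementation: the names s1, s2, length and sampleSize mirror the patch, and the byte counts in main are hypothetical.

object DoubleSampleSketch {
  // s1, s2: measured sizes of two disjoint samples of `sampleSize` elements.
  def estimate(s1: Long, s2: Long, length: Int, sampleSize: Int): Long = {
    // Objects shared across elements are charged to whichever sample visits
    // them first (state.visited marks them), so the smaller sample
    // approximates the purely per-element cost.
    val unshared = math.min(s1, s2)
    // The larger sample, which absorbed the shared objects, is added once;
    // the unshared cost is extrapolated over the remaining elements.
    math.max(s1, s2) + unshared * ((length - sampleSize) / sampleSize)
  }

  def main(args: Array[String]): Unit = {
    // 10000 elements; the first 100-element sample saw 48000 bytes because it
    // hit a large shared buffer, the second saw 4000 bytes of per-element
    // data: 48000 + 4000 * ((10000 - 100) / 100) = 444000 bytes.
    println(estimate(48000L, 4000L, 10000, 100))
  }
}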
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 67a9f75ff2..28915bd533 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -17,6 +17,8 @@
package org.apache.spark.util
+import scala.collection.mutable.ArrayBuffer
+
import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, FunSuite, PrivateMethodTester}
class DummyClass1 {}
@@ -96,6 +98,22 @@ class SizeEstimatorSuite
// Past size 100, we sample 100 elements, but we should still get the right size.
assertResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
+
+ val arr = new Array[Char](100000)
+ assertResult(200016)(SizeEstimator.estimate(arr))
+ assertResult(480032)(SizeEstimator.estimate(Array.fill(10000)(new DummyString(arr))))
+
+ val buf = new ArrayBuffer[DummyString]()
+ for (i <- 0 until 5000) {
+ buf.append(new DummyString(new Array[Char](10)))
+ }
+ assertResult(340016)(SizeEstimator.estimate(buf.toArray))
+
+ for (i <- 0 until 5000) {
+ buf.append(new DummyString(arr))
+ }
+ assertResult(683912)(SizeEstimator.estimate(buf.toArray))
+
// If an array contains the *same* element many times, we should only count it once.
val d1 = new DummyClass1
// 10 pointers plus 8-byte object