path: root/core/src/main/scala/spark/SizeEstimator.scala
blob: a3774fb0551274738fed0b1ce7dd9e47ac4f88a9
package spark

import java.lang.reflect.Field
import java.lang.reflect.Modifier
import java.lang.reflect.{Array => JArray}
import java.util.IdentityHashMap
import java.util.concurrent.ConcurrentHashMap
import java.util.Random

import scala.collection.mutable.ArrayBuffer


/**
 * Estimates the sizes of Java objects (number of bytes of memory they occupy),
 * for use in memory-aware caches.
 *
 * Based on the following JavaWorld article:
 * http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
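 *
 * A minimal usage sketch (illustrative only; the array contents here are
 * arbitrary and the reported size depends on the JVM's object layout):
 * {{{
 *   val bytes: Long = SizeEstimator.estimate(Array.fill(1000)("some string"))
 * }}}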
 */
object SizeEstimator {
  private val OBJECT_SIZE  = 8 // Minimum size of a java.lang.Object
  private val POINTER_SIZE = 4 // Size of an object reference
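  // Note: these values assume a 32-bit or compressed-reference JVM layout;
  // on a 64-bit JVM without compressed oops, references and object headers are larger.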

  // Sizes of primitive types
  private val BYTE_SIZE    = 1
  private val BOOLEAN_SIZE = 1
  private val CHAR_SIZE    = 2
  private val SHORT_SIZE   = 2
  private val INT_SIZE     = 4
  private val LONG_SIZE    = 8
  private val FLOAT_SIZE   = 4
  private val DOUBLE_SIZE  = 8

  // A cache of ClassInfo objects for each class
  private val classInfos = new ConcurrentHashMap[Class[_], ClassInfo]
  classInfos.put(classOf[Object], new ClassInfo(OBJECT_SIZE, Nil))

  /**
   * The state of an ongoing size estimation. Contains a stack of objects
   * to visit as well as an IdentityHashMap of visited objects, and provides
   * utility methods for enqueueing new objects to visit.
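   * The IdentityHashMap tracks visited objects by reference equality, so
   * shared sub-objects and reference cycles are each counted only once.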
   */
  private class SearchState {
    val visited = new IdentityHashMap[AnyRef, AnyRef]
    val stack = new ArrayBuffer[AnyRef]
    var size = 0L

    def enqueue(obj: AnyRef) {
      if (obj != null && !visited.containsKey(obj)) {
        visited.put(obj, null)
        stack += obj
      }
    }

    def isFinished(): Boolean = stack.isEmpty

    def dequeue(): AnyRef = {
      val elem = stack.last
      stack.trimEnd(1)
      return elem
    }
  }

  /**
   * Cached information about each class. We remember two things: the
   * "shell size" of the class (size of all non-static fields plus the
   * java.lang.Object size), and any fields that are pointers to objects.
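   *
   * For example, a class extending Object directly with one Int field and one
   * String field gets a shell size of OBJECT_SIZE + INT_SIZE + POINTER_SIZE
   * = 8 + 4 + 4 = 16 bytes, with the String field as its only pointer field.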
   */
  private class ClassInfo(
    val shellSize: Long,
    val pointerFields: List[Field]) {}

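  /**
   * Estimate the number of bytes that the given object, and the objects
   * reachable from it, occupy on the JVM heap.
   */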
  def estimate(obj: AnyRef): Long = {
    val state = new SearchState
    state.enqueue(obj)
    while (!state.isFinished) {
      visitSingleObject(state.dequeue(), state)
    }
    return state.size
  }

  private def visitSingleObject(obj: AnyRef, state: SearchState) {
    val cls = obj.getClass
    if (cls.isArray) {
      visitArray(obj, cls, state)
    } else {
      val classInfo = getClassInfo(cls)
      state.size += classInfo.shellSize
      for (field <- classInfo.pointerFields) {
        state.enqueue(field.get(obj))
      }
    }
  }

  private def visitArray(array: AnyRef, cls: Class[_], state: SearchState) {
    val length = JArray.getLength(array)
    val elementClass = cls.getComponentType
    if (elementClass.isPrimitive) {
      state.size += length * primitiveSize(elementClass)
    } else {
      state.size += length * POINTER_SIZE
      if (length <= 100) {
        for (i <- 0 until length) {
          state.enqueue(JArray.get(array, i))
        }
      } else {
        // Estimate the size of a large array by sampling elements.
        // TODO: Add a config setting for turning this off?
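        // Sample 100 elements uniformly at random (with a fixed seed for
        // determinism) and scale their total size by length / 100 to
        // extrapolate an estimate for the whole array.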
        var size = 0.0
        val rand = new Random(42)
        for (i <- 0 until 100) {
          val elem = JArray.get(array, rand.nextInt(length))
          size += SizeEstimator.estimate(elem)
        }
        state.size += ((length / 100.0) * size).toLong
      }
    }
  }

  private def primitiveSize(cls: Class[_]): Long = {
    if (cls == classOf[Byte])
      BYTE_SIZE
    else if (cls == classOf[Boolean])
      BOOLEAN_SIZE
    else if (cls == classOf[Char])
      CHAR_SIZE
    else if (cls == classOf[Short])
      SHORT_SIZE
    else if (cls == classOf[Int])
      INT_SIZE
    else if (cls == classOf[Long])
      LONG_SIZE
    else if (cls == classOf[Float])
      FLOAT_SIZE
    else if (cls == classOf[Double])
      DOUBLE_SIZE
    else throw new IllegalArgumentException(
      "Non-primitive class " + cls + " passed to primitiveSize()")
  }

  /**
   * Get or compute the ClassInfo for a given class.
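   * The computation walks up the superclass chain, so the recursion bottoms
   * out at the java.lang.Object entry pre-seeded into classInfos above.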
   */
  private def getClassInfo(cls: Class[_]): ClassInfo = {
    // Check whether we've already cached a ClassInfo for this class
    val info = classInfos.get(cls)
    if (info != null) {
      return info
    }
    
    val parent = getClassInfo(cls.getSuperclass)
    var shellSize = parent.shellSize
    var pointerFields = parent.pointerFields

    for (field <- cls.getDeclaredFields) {
      if (!Modifier.isStatic(field.getModifiers)) {
        val fieldClass = field.getType
        if (fieldClass.isPrimitive) {
          shellSize += primitiveSize(fieldClass)
        } else {
          field.setAccessible(true) // Enable future get()'s on this field
          shellSize += POINTER_SIZE
          pointerFields = field :: pointerFields
        }
      }
    }

    // Create and cache a new ClassInfo
    val newInfo = new ClassInfo(shellSize, pointerFields)
    classInfos.put(cls, newInfo)
    return newInfo
  }
}