summaryrefslogtreecommitdiff
path: root/src/reflect/scala/reflect/internal/util/WeakHashSet.scala
blob: 83d2a3453b1b512b4866a2db55e4a0bcdbb3968f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
package scala
package reflect.internal.util

import java.lang.ref.{WeakReference, ReferenceQueue}
import scala.annotation.tailrec
import scala.collection.generic.Clearable
import scala.collection.mutable.{Set => MSet}

/**
 * A HashSet where the elements are stored weakly. Elements in this set are eligible for GC if no other
 * hard references are associated with them. Its primary use case is as a canonical reference
 * identity holder (aka "hash-consing") via findEntryOrUpdate
 *
 * This Set implementation cannot hold null. Any attempt to put a null in it will result in a NullPointerException
 *
 * This set implementation is not in general thread safe without external concurrency control. However it behaves
 * properly when GC concurrently collects elements in this set.
 *
 * @param initialCapacity requested number of buckets; rounded up to the next power of two
 * @param loadFactor      ratio of element count to bucket count above which the table is doubled
 */
final class WeakHashSet[A <: AnyRef](val initialCapacity: Int, val loadFactor: Double) extends Set[A] with Function1[A, Boolean] with MSet[A] {

  import WeakHashSet._

  /** Creates a set with the default initial capacity and load factor from the companion object. */
  def this() = this(initialCapacity = WeakHashSet.defaultInitialCapacity, loadFactor = WeakHashSet.defaultLoadFactor)

  type This = WeakHashSet[A]

  /**
   * queue of Entries that hold elements scheduled for GC
   * the removeStaleEntries() method works through the queue to remove
   * stale entries from the table
   */
  private[this] val queue = new ReferenceQueue[A]

  /**
   * the number of elements in this set
   */
  private[this] var count = 0

  /**
   * from a specified initial capacity compute the capacity we'll use as being the next
   * power of two equal to or greater than the specified initial capacity
   *
   * The result is capped at 2^30, the largest power of two representable as a positive Int;
   * without the cap the doubling below would overflow and never terminate for larger requests.
   *
   * @throws IllegalArgumentException if the initial capacity is negative
   */
  private def computeCapacity: Int = {
    if (initialCapacity < 0) throw new IllegalArgumentException("initial capacity cannot be less than 0")
    val maxCapacity = 1 << 30
    var candidate = 1
    while (candidate < initialCapacity && candidate < maxCapacity) {
      candidate *= 2
    }
    candidate
  }

  /**
   * the underlying table of entries which is an array of Entry linked lists
   */
  private[this] var table = new Array[Entry[A]](computeCapacity)

  /**
   * the limit at which we'll increase the size of the hash table
   */
  var threshhold: Int = computeThreshHold

  /**
   * the element count above which the current table is considered too full:
   * table length times load factor, rounded up
   */
  private[this] def computeThreshHold: Int = (table.length * loadFactor).ceil.toInt

  /**
   * find the bucket associated with an element's hash code
   */
  private[this] def bucketFor(hash: Int): Int = {
    // spread the bits around to try to avoid accidental collisions using the
    // same algorithm as java.util.HashMap
    var h = hash
    h ^= h >>> 20 ^ h >>> 12
    h ^= h >>> 7 ^ h >>> 4

    // this is finding h % table.length, but takes advantage of the
    // fact that table length is a power of 2,
    // if you don't do bit flipping in your head, if table.length
    // is binary 100000.. (with n 0s) then table.length - 1
    // is 1111.. with n 1's.
    // In other words this masks on the last n bits in the hash
    h & (table.length - 1)
  }

  /**
   * remove a single entry from a linked list in a given bucket
   *
   * @param bucket    index of the bucket whose list contains `entry`
   * @param prevEntry the entry preceding `entry` in the list, or null if `entry` is the head
   * @param entry     the entry to unlink
   */
  private[this] def remove(bucket: Int, prevEntry: Entry[A], entry: Entry[A]): Unit = {
    prevEntry match {
      case null => table(bucket) = entry.tail
      case _ => prevEntry.tail = entry.tail
    }
    count -= 1
  }

  /**
   * remove entries associated with elements that have been gc'ed
   */
  private[this] def removeStaleEntries(): Unit = {
    def poll(): Entry[A] = queue.poll().asInstanceOf[Entry[A]]

    @tailrec
    def queueLoop(): Unit = {
      val stale = poll()
      if (stale != null) {
        // the hash is cached on the entry because the element itself is no longer available
        val bucket = bucketFor(stale.hash)

        // entries are compared by reference; a stale entry that is no longer
        // linked into the table is simply not found and nothing is removed
        @tailrec
        def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = if (stale eq entry) remove(bucket, prevEntry, entry)
        else if (entry != null) linkedListLoop(entry, entry.tail)

        linkedListLoop(null, table(bucket))

        queueLoop()
      }
    }

    queueLoop()
  }

  /**
   * Double the size of the internal table
   */
  private[this] def resize(): Unit = {
    val oldTable = table
    table = new Array[Entry[A]](oldTable.length * 2)
    threshhold = computeThreshHold

    @tailrec
    def tableLoop(oldBucket: Int): Unit = if (oldBucket < oldTable.length) {
      @tailrec
      def linkedListLoop(entry: Entry[A]): Unit = entry match {
        case null => ()
        case _ => {
          // bucketFor now masks against the new table length
          val bucket = bucketFor(entry.hash)
          // remember the rest of the old list before relinking this entry
          val oldNext = entry.tail
          entry.tail = table(bucket)
          table(bucket) = entry
          linkedListLoop(oldNext)
        }
      }
      linkedListLoop(oldTable(oldBucket))

      tableLoop(oldBucket + 1)
    }
    tableLoop(0)
  }

  // from scala.reflect.internal.Set, find an element or null if it isn't contained
  override def findEntry(elem: A): A = elem match {
    case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
    case _    => {
      removeStaleEntries()
      val hash = elem.hashCode
      val bucket = bucketFor(hash)

      @tailrec
      def linkedListLoop(entry: Entry[A]): A = entry match {
        case null                    => null.asInstanceOf[A]
        case _                       => {
          val entryElem = entry.get
          if (elem == entryElem) entryElem
          else linkedListLoop(entry.tail)
        }
      }

      linkedListLoop(table(bucket))
    }
  }

  // add an element to this set unless it's already in there and return the element
  // (the already-stored equal element if there is one, otherwise `elem` itself)
  def findEntryOrUpdate(elem: A): A = elem match {
    case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
    case _    => {
      removeStaleEntries()
      val hash = elem.hashCode
      val bucket = bucketFor(hash)
      val oldHead = table(bucket)

      // prepend a new entry to the bucket's list, growing the table if it is now too full
      def add() = {
        table(bucket) = new Entry(elem, hash, oldHead, queue)
        count += 1
        if (count > threshhold) resize()
        elem
      }

      @tailrec
      def linkedListLoop(entry: Entry[A]): A = entry match {
        case null                    => add()
        case _                       => {
          val entryElem = entry.get
          if (elem == entryElem) entryElem
          else linkedListLoop(entry.tail)
        }
      }

      linkedListLoop(oldHead)
    }
  }

  // add an element to this set unless it's already in there and return this set
  // (same lookup-or-insert as findEntryOrUpdate, including the NullPointerException for null)
  override def +(elem: A): this.type = {
    findEntryOrUpdate(elem)
    this
  }

  def +=(elem: A): this.type = this + elem

  // from scala.reflect.internal.Set
  override def addEntry(x: A): Unit = { this += x }

  // remove an element from this set and return this set
  // (removing null is a no-op since this set can never contain it)
  override def -(elem: A): this.type = elem match {
    case null => this
    case _ => {
      removeStaleEntries()
      val bucket = bucketFor(elem.hashCode)

      @tailrec
      def linkedListLoop(prevEntry: Entry[A], entry: Entry[A]): Unit = entry match {
        case null => ()
        case _ if (elem == entry.get) => remove(bucket, prevEntry, entry)
        case _ => linkedListLoop(entry, entry.tail)
      }

      linkedListLoop(null, table(bucket))
      this
    }
  }

  def -=(elem: A): this.type = this - elem

  // empty this set
  override def clear(): Unit = {
    // keep the current number of buckets, just drop every entry
    table = new Array[Entry[A]](table.length)
    threshhold = computeThreshHold
    count = 0

    // drain the queue - doesn't do anything because we're throwing away all the values anyway
    @tailrec def queueLoop(): Unit = if (queue.poll() != null) queueLoop()
    queueLoop()
  }

  // a new, empty set with the same initial capacity and load factor as this one
  override def empty: This = new WeakHashSet[A](initialCapacity, loadFactor)

  // the number of elements in this set
  override def size: Int = {
    removeStaleEntries()
    count
  }

  override def apply(x: A): Boolean = this contains x

  override def foreach[U](f: A => U): Unit = iterator foreach f

  // It has the `()` because iterator runs `removeStaleEntries()`
  override def toList(): List[A] = iterator.toList

  // Iterator over all the elements in this set in no particular order
  override def iterator: Iterator[A] = {
    removeStaleEntries()

    new Iterator[A] {

      /**
       * the bucket currently being examined. Initially it's set past the last bucket and will be decremented
       */
      private[this] var currentBucket: Int = table.length

      /**
       * the entry that was last examined
       */
      private[this] var entry: Entry[A] = null

      /**
       * the element that will be the result of the next call to next()
       */
      private[this] var lookaheadelement: A = null.asInstanceOf[A]

      @tailrec
      def hasNext: Boolean = {
        // walk backwards through the buckets until a non-empty one is found
        while (entry == null && currentBucket > 0) {
          currentBucket -= 1
          entry = table(currentBucket)
        }

        if (entry == null) false
        else {
          // holding the element here keeps it strongly reachable until next() hands it out
          lookaheadelement = entry.get
          if (lookaheadelement == null) {
            // element null means the weakref has been cleared since we last did a removeStaleEntries(), move to the next entry
            entry = entry.tail
            hasNext
          } else {
            true
          }
        }
      }

      def next(): A = {
        // compute the lookahead on demand so that next() also works when the
        // caller did not invoke hasNext first
        if (lookaheadelement == null && !hasNext)
          throw new IndexOutOfBoundsException("next on an empty iterator")
        val result = lookaheadelement
        lookaheadelement = null.asInstanceOf[A]
        entry = entry.tail
        result
      }
    }
  }

  /**
   * Diagnostic information about the internals of this set. Not normally
   * needed by ordinary code, but may be useful for diagnosing performance problems
   */
  private[util] class Diagnostics {
    /**
     * Verify that the internal structure of this hash set is fully consistent.
     * Throws an assertion error on any problem. In order for it to be reliable
     * the entries must be stable. If any are garbage collected during validation
     * then an assertion may inappropriately fire.
     */
    def fullyValidate: Unit = {
      var computedCount = 0
      var bucket = 0
      while (bucket < table.length) {
        var entry = table(bucket)
        while (entry != null) {
          // read the referent once so it cannot be cleared between the null check and hashCode
          val elem = entry.get
          assert(elem != null, s"$entry had a null value indicated that gc activity was happening during diagnostic validation or that a null value was inserted")
          computedCount += 1
          val cachedHash = entry.hash
          val realHash = elem.hashCode
          assert(cachedHash == realHash, s"for $entry cached hash was $cachedHash but should have been $realHash")
          val computedBucket = bucketFor(realHash)
          assert(computedBucket == bucket, s"for $entry the computed bucket was $computedBucket but should have been $bucket")

          entry = entry.tail
        }

        bucket += 1
      }

      assert(computedCount == count, s"The computed count was $computedCount but should have been $count")
    }

    /**
     *  Produces a diagnostic dump of the table that underlies this hash set.
     */
    def dump = table.deep

    /**
     * Number of buckets that hold collisions. Useful for diagnosing performance issues.
     */
    def collisionBucketsCount: Int =
      (table count (entry => entry != null && entry.tail != null))

    /**
     * Number of buckets that are occupied in this hash table.
     */
    def fullBucketsCount: Int =
      (table count (entry => entry != null))

    /**
     *  Number of buckets in the table
     */
    def bucketsCount: Int = table.length
  }

  /** A fresh view of the diagnostic information for this set. */
  private[util] def diagnostics = new Diagnostics
}

/**
 * Companion object for WeakHashSet
 */
object WeakHashSet {
  /**
   * A single entry in a WeakHashSet. It's a WeakReference plus a cached hash code and
   * a link to the next Entry in the same bucket
   *
   * @param element the element held weakly by this entry
   * @param hash    the element's hash code, cached so it is still available once the element is gone
   * @param tail    the next entry in the same bucket, or null at the end of the list
   * @param queue   the reference queue this entry is registered with
   */
  private class Entry[A](element: A, val hash:Int, var tail: Entry[A], queue: ReferenceQueue[A]) extends WeakReference[A](element, queue)

  /** Initial capacity used when none is specified. */
  val defaultInitialCapacity: Int = 16

  /** Load factor used when none is specified. */
  val defaultLoadFactor: Double = 0.75

  /**
   * Creates an empty WeakHashSet.
   *
   * @param initialCapacity requested initial capacity of the new set
   * @param loadFactor      load factor of the new set
   */
  def apply[A <: AnyRef](initialCapacity: Int = WeakHashSet.defaultInitialCapacity, loadFactor: Double = WeakHashSet.defaultLoadFactor): WeakHashSet[A] =
    // pass the caller's loadFactor through; previously defaultLoadFactor was always used
    new WeakHashSet[A](initialCapacity, loadFactor)
}