// Source: root/src/library/scala/collection/parallel/mutable/ParHashTable.scala
// blob: 2617685a3d4d8d5e7afa90e89f318d6113d0eea2
package scala.collection
package parallel.mutable




import collection.mutable.HashEntry
import collection.parallel.ParIterableIterator



/** Provides functionality for hash tables with linked list buckets,
 *  enriching the data structure by fulfilling certain requirements
 *  for their parallel construction and iteration.
 *
 *  The size map is initialized as soon as this trait is mixed in, because
 *  `EntryIterator.split` relies on it to count the entries of a slot range
 *  without walking every chain.
 */
trait ParHashTable[K, Entry >: Null <: HashEntry[K, Entry]] extends collection.mutable.HashTable[K, Entry] {

  // always initialize size map
  if (!isSizeMapDefined) sizeMapInitAndRebuild

  /** A parallel iterator returning all the entries.
   *
   *  Iteration invariant: `es` is the current position inside the chain of an
   *  already visited slot (or `null`), and `idx` is the index of the next table
   *  slot that has not been visited yet. The slots still to be visited are
   *  therefore `[idx, until)`. Whoever creates an iterator must pass arguments
   *  that respect this invariant, i.e. `es` must not belong to a slot in
   *  `[idx, until)`, otherwise that slot's chain would be traversed twice.
   *
   *  @tparam T          the type of the items produced from the entries
   *  @tparam IterRepr   the concrete iterator type produced by `newIterator`
   *  @param idx         index of the next table slot to visit
   *  @param until       exclusive upper bound of the slots this iterator may visit
   *  @param totalsize   number of items this iterator is expected to produce
   *  @param es          the chain position to start from, may be `null`
   */
  abstract class EntryIterator[T, +IterRepr <: ParIterableIterator[T]]
    (private var idx: Int, private val until: Int, private val totalsize: Int, private var es: Entry)
  extends ParIterableIterator[T] {
    // the table array as referenced at the time this iterator is created
    private val itertable = table
    // number of items handed out by `next` so far
    private var traversed = 0
    // position `es` on the first available entry
    scan()

    /** Converts a hash table entry into the item this iterator returns. */
    def entry2item(e: Entry): T

    /** Creates an iterator of the concrete type; the arguments have the same
     *  meaning as the constructor parameters of this class.
     */
    def newIterator(idxFrom: Int, idxUntil: Int, totalSize: Int, es: Entry): IterRepr

    def hasNext: Boolean = es != null

    /** Returns the item of the current entry and advances to the following one,
     *  moving on to later slots when the current chain is exhausted.
     */
    def next: T = {
      val res = es
      es = es.next
      scan()
      traversed += 1
      entry2item(res)
    }

    /** While there is no current entry, pulls the chain head out of the next
     *  unvisited slot. Stops at the first non-empty slot or when `idx` reaches `until`.
     */
    def scan(): Unit = {
      while (es == null && idx < until) {
        es = itertable(idx).asInstanceOf[Entry]
        idx = idx + 1
      }
    }

    def remaining: Int = totalsize - traversed

    /** Splits this iterator into iterators that together cover the items not yet traversed.
     *
     *  - With at most one item remaining, the iterator itself is returned.
     *  - With unvisited slots left, the first result takes the rest of the current
     *    chain plus the lower part of the unvisited slots, the second one takes
     *    the upper part.
     *  - With no unvisited slots left, only the current chain remains; it is
     *    copied into a buffer and the split is delegated to a `BufferIterator`.
     */
    def split: Seq[ParIterableIterator[T]] = if (remaining > 1) {
      // `until > idx` (rather than `until - idx > 1`): even a single unvisited slot
      // must be handed to one of the resulting iterators, otherwise its entries are lost
      if (until > idx) {
        // there is at least one more slot for the next iterator
        // divide the rest of the table
        val divsz = (until - idx) / 2

        // second iterator params: it owns the slots [sslot, until)
        val sslot = idx + divsz
        // its chain starts at slot `sslot`, so its next unvisited slot is `sslot + 1`;
        // this preserves the iteration invariant and avoids visiting `sslot` twice
        val sidx = sslot + 1
        val suntil = until
        val ses = itertable(sslot).asInstanceOf[Entry]
        val stotal = calcNumElems(sslot, suntil)

        // first iterator params: the current chain plus the slots [idx, sslot)
        val fidx = idx
        val funtil = sslot
        val fes = es
        // the new iterator starts with nothing traversed, so it gets what is
        // still remaining here minus the share of the second iterator
        val ftotal = remaining - stotal

        Seq(
          newIterator(fidx, funtil, ftotal, fes),
          newIterator(sidx, suntil, stotal, ses)
        )
      } else {
        // otherwise, all the slots have been visited - all that remains is the chain
        // so split the rest of the chain
        val arr = convertToArrayBuffer(es)
        val arrpit = new collection.parallel.BufferIterator[T](arr, 0, arr.length, signalDelegate)
        arrpit.split
      }
    } else Seq(this.asInstanceOf[IterRepr])

    /** Copies the items of the chain starting at `chainhead` into a buffer, in chain order. */
    private def convertToArrayBuffer(chainhead: Entry): mutable.ArrayBuffer[T] = {
      val buff = mutable.ArrayBuffer[Entry]()
      var curr = chainhead
      while (curr != null) {
        buff += curr
        curr = curr.next
      }
      buff map { e => entry2item(e) }
    }

    /** Counts the entries stored in the slots `[from, until)`.
     *
     *  Slots are grouped into size-map buckets of `sizeMapBucketSize` slots.
     *  Buckets lying completely inside the range are summed from the size map;
     *  only the partially covered first and last buckets are counted by walking
     *  their chains.
     */
    private def calcNumElems(from: Int, until: Int): Int = {
      // find the first bucket
      val fbindex = from / sizeMapBucketSize

      // find the last bucket - this must be derived from `until`; deriving it from
      // `from` makes both indices always equal and the size map is never used
      val lbindex = until / sizeMapBucketSize

      if (fbindex == lbindex) {
        // if first and last are the same, just count between `from` and `until`
        // return this count
        countElems(from, until)
      } else {
        // otherwise count in first, then count in last
        val fbuntil = ((fbindex + 1) * sizeMapBucketSize) min itertable.length
        val fbcount = countElems(from, fbuntil)
        val lbstart = lbindex * sizeMapBucketSize
        val lbcount = countElems(lbstart, until)

        // and finally count the elements in all the buckets between first and last using a sizemap
        val inbetween = countBucketSizes(fbindex + 1, lbindex)

        // return the sum
        fbcount + inbetween + lbcount
      }
    }

    /** Counts the entries in the slots `[from, until)` by walking every chain. */
    private def countElems(from: Int, until: Int): Int = {
      var c = 0
      var slot = from
      var curr: Entry = null
      while (slot < until) {
        curr = itertable(slot).asInstanceOf[Entry]
        while (curr ne null) {
          c += 1
          curr = curr.next
        }
        slot += 1
      }
      c
    }

    /** Sums the size-map values of the buckets `[fromBucket, untilBucket)`. */
    private def countBucketSizes(fromBucket: Int, untilBucket: Int): Int = {
      var c = 0
      var bucket = fromBucket
      while (bucket < untilBucket) {
        c += sizemap(bucket)
        bucket += 1
      }
      c
    }
  }

}