summaryrefslogtreecommitdiff
path: root/src/library/scala/collection/parallel/mutable/ParHashMap.scala
blob: fad7ddad597d89da9b67f46c05c796bd98ceaf31 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
/*                     __                                               *\
**     ________ ___   / /  ___     Scala API                            **
**    / __/ __// _ | / /  / _ |    (c) 2003-2011, LAMP/EPFL             **
**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
** /____/\___/_/ |_/____/_/ | |                                         **
**                          |/                                          **
\*                                                                      */


package scala.collection.parallel
package mutable



import scala.collection.generic._
import scala.collection.mutable.DefaultEntry
import scala.collection.mutable.HashEntry
import scala.collection.mutable.HashTable
import scala.collection.mutable.UnrolledBuffer
import scala.collection.parallel.Task



/** A parallel hash map.
 *
 *  `ParHashMap` is a parallel map which internally keeps elements within a hash table.
 *  It uses chaining to resolve collisions.
 *
 *  @tparam K        type of the keys in the parallel hash map
 *  @tparam V        type of the values in the parallel hash map
 *
 *  @define Coll `ParHashMap`
 *  @define coll parallel hash map
 *
 *  @author Aleksandar Prokopec
 *  @see  [[http://docs.scala-lang.org/overviews/parallel-collections/concrete-parallel-collections.html#parallel_hash_tables Scala's Parallel Collections Library overview]]
 *  section on Parallel Hash Tables for more information.
 */
@SerialVersionUID(1L)
class ParHashMap[K, V] private[collection] (contents: HashTable.Contents[K, DefaultEntry[K, V]])
extends ParMap[K, V]
   with GenericParMapTemplate[K, V, ParHashMap]
   with ParMapLike[K, V, ParHashMap[K, V], scala.collection.mutable.HashMap[K, V]]
   with ParHashTable[K, DefaultEntry[K, V]]
   with Serializable
{
self =>
  // `contents` is null when the map is created through the auxiliary constructor `this()`.
  initWithContents(contents)

  type Entry = scala.collection.mutable.DefaultEntry[K, V]

  /** Creates an empty parallel hash map (no initial table contents). */
  def this() = this(null)

  override def mapCompanion: GenericParMapCompanion[ParHashMap] = ParHashMap

  override def empty: ParHashMap[K, V] = new ParHashMap[K, V]

  protected[this] override def newCombiner = ParHashMapCombiner[K, V]

  /** A sequential `mutable.HashMap` built from this map's current hash table contents. */
  override def seq = new scala.collection.mutable.HashMap[K, V](hashTableContents)

  /** Splitter over the whole table: it is seeded with the chain in slot 0 and the slot range
   *  `[1, table.length)` (argument meaning as defined by `EntryIterator`, which lives elsewhere).
   */
  def splitter = new ParHashMapIterator(1, table.length, size, table(0).asInstanceOf[DefaultEntry[K, V]])

  override def size = tableSize

  override def clear() = clearTable()

  /** Returns `Some(value)` for `key`, or `None` when no entry holds that key. */
  def get(key: K): Option[V] = {
    val e = findEntry(key)
    if (e eq null) None
    else Some(e.value)
  }

  /** Associates `value` with `key`.
   *
   *  @return `None` when `findOrAddEntry` reports no existing entry (null), otherwise
   *          `Some` of the value that was overwritten.
   */
  def put(key: K, value: V): Option[V] = {
    val e = findOrAddEntry(key, value)
    if (e eq null) None
    else { val v = e.value; e.value = value; Some(v) }
  }

  /** Associates `value` with `key`, discarding any previous value. */
  def update(key: K, value: V): Unit = put(key, value)

  /** Removes `key`, returning its value if an entry was removed. */
  def remove(key: K): Option[V] = {
    val e = removeEntry(key)
    if (e ne null) Some(e.value)
    else None
  }

  /** Adds or overwrites the mapping `kv._1 -> kv._2`. */
  def += (kv: (K, V)): this.type = {
    val e = findOrAddEntry(kv._1, kv._2)
    // a non-null result is an already-present entry whose value must be replaced
    if (e ne null) e.value = kv._2
    this
  }

  /** Removes `key` if present. */
  def -=(key: K): this.type = { removeEntry(key); this }

  override def stringPrefix = "ParHashMap"

  /** Entry iterator that presents each `DefaultEntry` as a `(key, value)` pair. */
  class ParHashMapIterator(start: Int, untilIdx: Int, totalSize: Int, e: DefaultEntry[K, V])
  extends EntryIterator[(K, V), ParHashMapIterator](start, untilIdx, totalSize, e) {
    def entry2item(entry: DefaultEntry[K, V]) = (entry.key, entry.value)
    def newIterator(idxFrom: Int, idxUntil: Int, totalSz: Int, es: DefaultEntry[K, V]) =
      new ParHashMapIterator(idxFrom, idxUntil, totalSz, es)
  }

  protected def createNewEntry[V1](key: K, value: V1): Entry = {
    new Entry(key, value.asInstanceOf[V])
  }

  // Serialization: every entry is written as its key object followed by its value object.
  private def writeObject(out: java.io.ObjectOutputStream) {
    serializeTo(out, { entry =>
      out.writeObject(entry.key)
      out.writeObject(entry.value)
    })
  }

  // Deserialization: rebuilds entries by reading a key and then a value, mirroring `writeObject`.
  private def readObject(in: java.io.ObjectInputStream) {
    init(in, createNewEntry(in.readObject().asInstanceOf[K], in.readObject()))
  }

  /** Debugging aid: describes every violated table invariant (empty when the table is consistent).
   *
   *  Checks that each whole size-map bucket's recorded count matches the number of entries
   *  chained in its slots, and that every entry sits in the slot `index` computes for its key.
   */
  private[parallel] override def brokenInvariants = {
    // bucket by bucket, count elements (integer division: only whole buckets are checked)
    val buckets = for (i <- 0 until (table.length / sizeMapBucketSize)) yield checkBucket(i)

    // check if each element is in the position corresponding to its key
    val elems = for (i <- 0 until table.length) yield checkEntry(i)

    buckets.flatten ++ elems.flatten
  }

  /** Compares size-map bucket `i`'s recorded count with the entries actually chained in its slots. */
  private def checkBucket(i: Int) = {
    // Iterative so that a pathologically long collision chain cannot overflow the stack.
    def count(head: HashEntry[K, DefaultEntry[K, V]]): Int = {
      var n = 0
      var e = head
      while (e ne null) {
        n += 1
        e = e.next
      }
      n
    }
    val expected = sizemap(i)
    val found = ((i * sizeMapBucketSize) until ((i + 1) * sizeMapBucketSize)).foldLeft(0) {
      (acc, c) => acc + count(table(c))
    }
    if (found != expected) List("Found " + found + " elements, while sizemap showed " + expected)
    else Nil
  }

  /** Reports, in chain order, every entry in slot `i` whose key hashes to a different slot. */
  private def checkEntry(i: Int): List[String] = {
    // Iterative (not recursive) for the same stack-safety reason as `checkBucket`.
    val problems = List.newBuilder[String]
    var e: HashEntry[K, DefaultEntry[K, V]] = table(i)
    while (e ne null) {
      val hc = elemHashCode(e.key)
      val expectedIdx = index(hc)
      if (expectedIdx != i)
        problems += ("Element " + e.key + " at " + i + " with " + hc + " maps to " + expectedIdx)
      e = e.next
    }
    problems.result()
  }

}


/** $factoryInfo
 *  @define Coll `mutable.ParHashMap`
 *  @define coll parallel hash map
 */
object ParHashMap extends ParMapFactory[ParHashMap] {
  // NOTE(review): public mutable counter that nothing in the visible source reads; kept as public API.
  var iters = 0

  /** A new, empty mutable parallel hash map. */
  def empty[K, V]: ParHashMap[K, V] =
    new ParHashMap[K, V]()

  /** A fresh combiner that assembles key/value pairs into a `ParHashMap`. */
  def newCombiner[K, V]: Combiner[(K, V), ParHashMap[K, V]] =
    ParHashMapCombiner[K, V]

  /** Combiner factory used when a parallel operation produces a `ParHashMap`. */
  implicit def canBuildFrom[K, V]: CanCombineFrom[Coll, (K, V), ParHashMap[K, V]] =
    new CanCombineFromMap[K, V]
}


/** Combiner that collects key/value pairs and assembles them into a mutable `ParHashMap`.
 *
 *  `+=` wraps every pair in a `DefaultEntry` and appends it to one of
 *  `ParHashMapCombiner.numblocks` buckets, chosen by the top bits of the key's improved hash.
 *  `result` then builds the hash table, using parallel tasks for large inputs.
 *
 *  If the same key occurs more than once, the entry that comes later in its bucket supplies
 *  the final value. This matches `ParHashMap.put`/`+=`, which overwrite an existing value.
 *
 *  @param tableLoadFactor load factor of the resulting table (same units as `HashTable.defaultLoadFactor`)
 */
private[mutable] abstract class ParHashMapCombiner[K, V](private val tableLoadFactor: Int)
extends scala.collection.parallel.BucketCombiner[(K, V), ParHashMap[K, V], DefaultEntry[K, V], ParHashMapCombiner[K, V]](ParHashMapCombiner.numblocks)
   with scala.collection.mutable.HashTable.HashUtils[K]
{
  // Shift that keeps only the top `discriminantbits` bits of a hash; those bits select the bucket.
  private val nonmasklen = ParHashMapCombiner.nonmasklength
  // Seed for `improve`. It is also handed to `AddingHashTable`, presumably so each bucket's
  // entries land in one contiguous slot range (the invariant `assertCorrectBlock` checks).
  private val seedvalue = 27

  /** Appends a pair to the bucket selected by the top bits of its key's improved hash. */
  def +=(elem: (K, V)) = {
    sz += 1
    val hc = improve(elemHashCode(elem._1), seedvalue)
    val pos = (hc >>> nonmasklen)
    if (buckets(pos) eq null) {
      // initialize bucket
      buckets(pos) = new UnrolledBuffer[DefaultEntry[K, V]]()
    }
    // add to bucket
    buckets(pos) += new DefaultEntry(elem._1, elem._2)
    this
  }

  /** Builds the map from the collected buckets.
   *
   *  With at least `numblocks * sizeMapBucketSize` (1024) collected pairs, a pre-sized
   *  `AddingHashTable` is filled by `FillBlocks` tasks. Smaller inputs are inserted
   *  sequentially into an ordinary `HashTable`.
   */
  def result: ParHashMap[K, V] = if (size >= (ParHashMapCombiner.numblocks * sizeMapBucketSize)) {
    // construct a table sized up front for `size` entries
    val table = new AddingHashTable(size, tableLoadFactor, seedvalue)
    val bucks = buckets.map(b => if (b ne null) b.headPtr else null)
    val insertcount = combinerTaskSupport.executeAndWaitResult(new FillBlocks(bucks, table, 0, bucks.length))
    // duplicate keys are not counted by `insertEntry`, so this may be smaller than `size`
    table.setSize(insertcount)
    // TODO compare insertcount and size to see if compression is needed
    val c = table.hashTableContents
    new ParHashMap(c)
  } else {
    // construct a normal table and fill it sequentially
    // TODO parallelize by keeping separate sizemaps and merging them
    object table extends HashTable[K, DefaultEntry[K, V]] {
      type Entry = DefaultEntry[K, V]
      // Adds `e`; when its key is already present, the existing entry takes over `e`'s value.
      def insertEntry(e: Entry): Unit = {
        val existing = super.findOrAddEntry(e.key, e)
        if (existing ne null) existing.value = e.value
      }
      def createNewEntry[E](key: K, entry: E): Entry = entry.asInstanceOf[Entry]
      // `this.table` is the inherited slot array, not the enclosing `object table`
      sizeMapInit(this.table.length)
    }
    var i = 0
    while (i < ParHashMapCombiner.numblocks) {
      if (buckets(i) ne null) {
        for (elem <- buckets(i)) table.insertEntry(elem)
      }
      i += 1
    }
    new ParHashMap(table.hashTableContents)
  }

  /* classes */

  /** A hash table which will never resize itself. Knowing the number of elements in advance,
   *  it allocates the table of the required size when created.
   *
   *  Entries are added using the `insertEntry` method, which also updates the size map.
   *  It returns true if the key was newly inserted. If the key was already in the table, it
   *  replaces the existing entry's value with the new one and returns false. It does not
   *  update the number of elements in the table; call `setSize` once all entries are in.
   */
  private[ParHashMapCombiner] class AddingHashTable(numelems: Int, lf: Int, _seedvalue: Int) extends HashTable[K, DefaultEntry[K, V]] {
    import HashTable._
    _loadFactor = lf
    table = new Array[HashEntry[K, DefaultEntry[K, V]]](capacity(sizeForThreshold(_loadFactor, numelems)))
    tableSize = 0
    // `this.` selects HashTable's inherited seed, not the enclosing combiner's `seedvalue`
    this.seedvalue = _seedvalue
    threshold = newThreshold(_loadFactor, table.length)
    sizeMapInit(table.length)
    def setSize(sz: Int) = tableSize = sz
    def insertEntry(/*block: Int, */e: DefaultEntry[K, V]): Boolean = {
      val h = index(elemHashCode(e.key))
      // assertCorrectBlock(h, block)
      val olde = table(h).asInstanceOf[DefaultEntry[K, V]]

      // look for an entry that already holds this key
      var ce = olde
      while ((ce ne null) && ce.key != e.key) ce = ce.next

      if (ce eq null) {
        // key not present: prepend to the chain and record it in the size map
        e.next = olde
        table(h) = e
        nnSizeMapAdd(h)
        true
      } else {
        // key already present: the later value replaces the earlier one
        ce.value = e.value
        false
      }
    }
    // Debug helper: checks that slot `h` lies within block `block`'s slot range.
    private def assertCorrectBlock(h: Int, block: Int) {
      val blocksize = table.length / (1 << ParHashMapCombiner.discriminantbits)
      if (!(h >= block * blocksize && h < (block + 1) * blocksize)) {
        println("trying to put " + h + " into block no.: " + block + ", range: [" + block * blocksize + ", " + (block + 1) * blocksize + ">")
        assert(h >= block * blocksize && h < (block + 1) * blocksize)
      }
    }
    protected def createNewEntry[X](key: K, x: X) = ???
  }

  /* tasks */

  import UnrolledBuffer.Unrolled

  /** Task that inserts buckets `offset until offset + howmany` into `table`.
   *  Its `result` is the number of newly inserted (non-duplicate) keys.
   */
  class FillBlocks(buckets: Array[Unrolled[DefaultEntry[K, V]]], table: AddingHashTable, offset: Int, howmany: Int)
  extends Task[Int, FillBlocks] {
    var result = Int.MinValue
    def leaf(prev: Option[Int]) = {
      var i = offset
      val until = offset + howmany
      result = 0
      while (i < until) {
        result += fillBlock(i, buckets(i))
        i += 1
      }
    }
    // Walks every chunk of one bucket's unrolled list, inserting each entry into `table`.
    private def fillBlock(block: Int, elems: Unrolled[DefaultEntry[K, V]]) = {
      var insertcount = 0
      var unrolled = elems
      var i = 0
      val t = table
      while (unrolled ne null) {
        val chunkarr = unrolled.array
        val chunksz = unrolled.size
        while (i < chunksz) {
          val elem = chunkarr(i)
          // assertCorrectBlock(block, elem.key)
          if (t.insertEntry(elem)) insertcount += 1
          i += 1
        }
        i = 0
        unrolled = unrolled.next
      }
      insertcount
    }
    // Debug helper: checks that key `k` belongs to bucket `block`.
    private def assertCorrectBlock(block: Int, k: K) {
      val hc = improve(elemHashCode(k), seedvalue)
      if ((hc >>> nonmasklen) != block) {
        println(hc + " goes to " + (hc >>> nonmasklen) + ", while expected block is " + block)
        assert((hc >>> nonmasklen) == block)
      }
    }
    def split = {
      val fp = howmany / 2
      List(new FillBlocks(buckets, table, offset, fp), new FillBlocks(buckets, table, offset + fp, howmany - fp))
    }
    override def merge(that: FillBlocks) {
      this.result += that.result
    }
    def shouldSplitFurther = howmany > scala.collection.parallel.thresholdFromSize(ParHashMapCombiner.numblocks, combinerTaskSupport.parallelismLevel)
  }

}


private[parallel] object ParHashMapCombiner {
  /** Number of top hash bits that select a combiner bucket. */
  private[mutable] val discriminantbits: Int = 5
  /** Number of combiner buckets: one per possible value of the discriminant bits. */
  private[mutable] val numblocks: Int = 1 << discriminantbits
  /** Mask covering `discriminantbits` low-order bits (one less than `numblocks`). */
  private[mutable] val discriminantmask: Int = numblocks - 1
  /** Right shift that brings the top `discriminantbits` bits of a 32-bit hash down to bit 0. */
  private[mutable] val nonmasklength: Int = 32 - discriminantbits

  /** A combiner whose result table uses `HashTable.defaultLoadFactor`.
   *  (Historically this also mixed in `EnvironmentPassingCombiner`.)
   */
  def apply[K, V] = new ParHashMapCombiner[K, V](HashTable.defaultLoadFactor) {}
}