-rw-r--r--  src/library/scala/collection/mutable/HashTable.scala           | 32
-rw-r--r--  src/library/scala/collection/parallel/mutable/ParHashMap.scala | 10
-rw-r--r--  test/files/jvm/serialization.check                             |  8
-rw-r--r--  test/files/run/t5293-map.scala                                 | 88
4 files changed, 122 insertions, 16 deletions
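The patch below threads a per-table seed value through the hash-code "improvement" step, so that different hash tables no longer assign the same keys to buckets in exactly the same order (SI-5293). As a rough illustration of the arithmetic being added, here is a minimal standalone sketch, not the library code itself; the object and method names (SeededImprove, scramble) are made up for the example, and only the scramble-then-rotate computation mirrors the diff:

// Sketch of the seeded hash improvement introduced by this commit.
// scramble() reproduces the existing byte-reversal/multiplication step;
// improve() then rotates the bits by a per-table seed so that two tables
// spread the same keys differently.
object SeededImprove {
  private def scramble(hcode: Int): Int = {
    var i = hcode * 0x9e3775cd
    i = java.lang.Integer.reverseBytes(i)
    i * 0x9e3775cd
  }

  def improve(hcode: Int, seed: Int): Int = {
    val i = scramble(hcode)
    val rotation = seed % 32
    (i >>> rotation) | (i << (32 - rotation))
  }
}

In the patch itself the seed defaults to Integer.bitCount(table.length - 1), so tables of different capacities rotate by different amounts, while the ParHashMap combiner fixes its seed to 27.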
diff --git a/src/library/scala/collection/mutable/HashTable.scala b/src/library/scala/collection/mutable/HashTable.scala
index cdf1b78f29..5b3e07b826 100644
--- a/src/library/scala/collection/mutable/HashTable.scala
+++ b/src/library/scala/collection/mutable/HashTable.scala
@@ -52,6 +52,10 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
    */
   @transient protected var sizemap: Array[Int] = null
 
+  @transient var seedvalue: Int = tableSizeSeed
+
+  protected def tableSizeSeed = Integer.bitCount(table.length - 1)
+
   protected def initialSize: Int = HashTable.initialSize
 
   private def lastPopulatedIndex = {
@@ -70,14 +74,16 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
   private[collection] def init[B](in: java.io.ObjectInputStream, f: (A, B) => Entry) {
     in.defaultReadObject
 
-    _loadFactor = in.readInt
+    _loadFactor = in.readInt()
     assert(_loadFactor > 0)
 
-    val size = in.readInt
+    val size = in.readInt()
     tableSize = 0
     assert(size >= 0)
-
-    val smDefined = in.readBoolean
+
+    seedvalue = in.readInt()
+
+    val smDefined = in.readBoolean()
 
     table = new Array(capacity(sizeForThreshold(_loadFactor, size)))
     threshold = newThreshold(_loadFactor, table.size)
@@ -86,7 +92,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
 
     var index = 0
     while (index < size) {
-      addEntry(f(in.readObject.asInstanceOf[A], in.readObject.asInstanceOf[B]))
+      addEntry(f(in.readObject().asInstanceOf[A], in.readObject().asInstanceOf[B]))
       index += 1
     }
   }
@@ -103,6 +109,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
     out.defaultWriteObject
     out.writeInt(_loadFactor)
     out.writeInt(tableSize)
+    out.writeInt(seedvalue)
     out.writeBoolean(isSizeMapDefined)
     foreachEntry { entry =>
       out.writeObject(entry.key)
@@ -314,7 +321,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
   // this is of crucial importance when populating the table in parallel
   protected final def index(hcode: Int) = {
     val ones = table.length - 1
-    val improved = improve(hcode)
+    val improved = improve(hcode, seedvalue)
     val shifted = (improved >> (32 - java.lang.Integer.bitCount(ones))) & ones
     shifted
   }
@@ -325,6 +332,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
     table = c.table
     tableSize = c.tableSize
     threshold = c.threshold
+    seedvalue = c.seedvalue
     sizemap = c.sizemap
   }
   if (alwaysInitSizeMap && sizemap == null) sizeMapInitAndRebuild
@@ -335,6 +343,7 @@ trait HashTable[A, Entry >: Null <: HashEntry[A, Entry]] extends HashTable.HashU
     table,
     tableSize,
     threshold,
+    seedvalue,
     sizemap
   )
 }
@@ -368,7 +377,7 @@ private[collection] object HashTable {
 
     protected def elemHashCode(key: KeyType) = key.##
 
-    protected final def improve(hcode: Int) = {
+    protected final def improve(hcode: Int, seed: Int) = {
       /* Murmur hash
        *  m = 0x5bd1e995
        *  r = 24
@@ -396,7 +405,7 @@
       * */
      var i = hcode * 0x9e3775cd
      i = java.lang.Integer.reverseBytes(i)
-     i * 0x9e3775cd
+     i = i * 0x9e3775cd
      // a slower alternative for byte reversal:
      // i = (i << 16) | (i >> 16)
      // i = ((i >> 8) & 0x00ff00ff) | ((i << 8) & 0xff00ff00)
@@ -420,6 +429,11 @@
      // h = h ^ (h >>> 14)
      // h = h + (h << 4)
      // h ^ (h >>> 10)
+
+     // the rest of the computation is due to SI-5293
+     val rotation = seed % 32
+     val rotated = (i >>> rotation) | (i << (32 - rotation))
+     rotated
     }
   }
 
@@ -442,6 +456,7 @@ private[collection] object HashTable {
     val table: Array[HashEntry[A, Entry]],
     val tableSize: Int,
    val threshold: Int,
+    val seedvalue: Int,
     val sizemap: Array[Int]
   ) {
     import collection.DebugUtils._
@@ -452,6 +467,7 @@ private[collection] object HashTable {
      append("Table: [" + arrayString(table, 0, table.length) + "]")
      append("Table size: " + tableSize)
      append("Load factor: " + loadFactor)
+     append("Seedvalue: " + seedvalue)
      append("Threshold: " + threshold)
      append("Sizemap: [" + arrayString(sizemap, 0, sizemap.length) + "]")
    }
diff --git a/src/library/scala/collection/parallel/mutable/ParHashMap.scala b/src/library/scala/collection/parallel/mutable/ParHashMap.scala
index 15ffd3fdd2..21a5b05749 100644
--- a/src/library/scala/collection/parallel/mutable/ParHashMap.scala
+++ b/src/library/scala/collection/parallel/mutable/ParHashMap.scala
@@ -160,10 +160,11 @@ extends collection.parallel.BucketCombiner[(K, V), ParHashMap[K, V], DefaultEntr
   import collection.parallel.tasksupport._
   private var mask = ParHashMapCombiner.discriminantmask
   private var nonmasklen = ParHashMapCombiner.nonmasklength
+  private var seedvalue = 27
 
   def +=(elem: (K, V)) = {
     sz += 1
-    val hc = improve(elemHashCode(elem._1))
+    val hc = improve(elemHashCode(elem._1), seedvalue)
     val pos = (hc >>> nonmasklen)
     if (buckets(pos) eq null) {
       // initialize bucket
@@ -176,7 +177,7 @@
 
   def result: ParHashMap[K, V] = if (size >= (ParHashMapCombiner.numblocks * sizeMapBucketSize)) { // 1024
     // construct table
-    val table = new AddingHashTable(size, tableLoadFactor)
+    val table = new AddingHashTable(size, tableLoadFactor, seedvalue)
     val bucks = buckets.map(b => if (b ne null) b.headPtr else null)
     val insertcount = executeAndWaitResult(new FillBlocks(bucks, table, 0, bucks.length))
     table.setSize(insertcount)
@@ -210,11 +211,12 @@ extends collection.parallel.BucketCombiner[(K, V), ParHashMap[K, V], DefaultEntr
    * and true if the key was successfully inserted. It does not update the number of elements
    * in the table.
    */
-  private[ParHashMapCombiner] class AddingHashTable(numelems: Int, lf: Int) extends HashTable[K, DefaultEntry[K, V]] {
+  private[ParHashMapCombiner] class AddingHashTable(numelems: Int, lf: Int, _seedvalue: Int) extends HashTable[K, DefaultEntry[K, V]] {
     import HashTable._
     _loadFactor = lf
     table = new Array[HashEntry[K, DefaultEntry[K, V]]](capacity(sizeForThreshold(_loadFactor, numelems)))
     tableSize = 0
+    seedvalue = _seedvalue
     threshold = newThreshold(_loadFactor, table.length)
     sizeMapInit(table.length)
     def setSize(sz: Int) = tableSize = sz
@@ -285,7 +287,7 @@
       insertcount
     }
     private def assertCorrectBlock(block: Int, k: K) {
-      val hc = improve(elemHashCode(k))
+      val hc = improve(elemHashCode(k), seedvalue)
       if ((hc >>> nonmasklen) != block) {
         println(hc + " goes to " + (hc >>> nonmasklen) + ", while expected block is " + block)
         assert((hc >>> nonmasklen) == block)
diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check
index 67b77639a2..81b68f0f5d 100644
--- a/test/files/jvm/serialization.check
+++ b/test/files/jvm/serialization.check
@@ -156,8 +156,8 @@ x = BitSet(0, 8, 9)
 y = BitSet(0, 8, 9)
 x equals y: true, y equals x: true
 
-x = Map(C -> 3, B -> 2, A -> 1)
-y = Map(C -> 3, A -> 1, B -> 2)
+x = Map(A -> 1, C -> 3, B -> 2)
+y = Map(A -> 1, C -> 3, B -> 2)
 x equals y: true, y equals x: true
 
 x = Set(buffers, title, layers)
@@ -283,8 +283,8 @@ x = ParArray(abc, def, etc)
 y = ParArray(abc, def, etc)
 x equals y: true, y equals x: true
 
-x = ParHashMap(1 -> 2, 2 -> 4)
-y = ParHashMap(1 -> 2, 2 -> 4)
+x = ParHashMap(2 -> 4, 1 -> 2)
+y = ParHashMap(2 -> 4, 1 -> 2)
 x equals y: true, y equals x: true
 
 x = ParCtrie(1 -> 2, 2 -> 4)
diff --git a/test/files/run/t5293-map.scala b/test/files/run/t5293-map.scala
new file mode 100644
index 0000000000..9e186894fc
--- /dev/null
+++ b/test/files/run/t5293-map.scala
@@ -0,0 +1,88 @@
+
+
+
+import scala.collection.JavaConverters._
+
+
+
+object Test extends App {
+
+  def bench(label: String)(body: => Unit): Long = {
+    val start = System.nanoTime
+
+    0.until(10).foreach(_ => body)
+
+    val end = System.nanoTime
+
+    //println("%s: %s ms".format(label, (end - start) / 1000.0 / 1000.0))
+
+    end - start
+  }
+
+  def benchJava(values: java.util.Map[Int, Int]) = {
+    bench("Java Map") {
+      val m = new java.util.HashMap[Int, Int]
+
+      m.putAll(values)
+    }
+  }
+
+  def benchScala(values: Iterable[(Int, Int)]) = {
+    bench("Scala Map") {
+      val m = new scala.collection.mutable.HashMap[Int, Int]
+
+      m ++= values
+    }
+  }
+
+  def benchScalaSorted(values: Iterable[(Int, Int)]) = {
+    bench("Scala Map sorted") {
+      val m = new scala.collection.mutable.HashMap[Int, Int]
+
+      m ++= values.toArray.sorted
+    }
+  }
+
+  def benchScalaPar(values: Iterable[(Int, Int)]) = {
+    bench("Scala ParMap") {
+      val m = new scala.collection.parallel.mutable.ParHashMap[Int, Int] map { x => x }
+
+      m ++= values
+    }
+  }
+
+  val total = 50000
+  val values = (0 until total) zip (0 until total)
+  val map = scala.collection.mutable.HashMap.empty[Int, Int]
+
+  map ++= values
+
+  // warmup
+  for (x <- 0 until 5) {
+    benchJava(map.asJava)
+    benchScala(map)
+    benchScalaPar(map)
+    benchJava(map.asJava)
+    benchScala(map)
+    benchScalaPar(map)
+  }
+
+  val javamap = benchJava(map.asJava)
+  val scalamap = benchScala(map)
+  val scalaparmap = benchScalaPar(map)
+
+  // println(javamap)
+  // println(scalamap)
+  // println(scalaparmap)
+
+  assert(scalamap < (javamap * 4))
+  assert(scalaparmap < (javamap * 4))
+}
+
+
+
+
+
+
+
+
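The new test exercises the usage pattern reported in SI-5293: filling a fresh map with the contents of an existing hash table. A compressed sketch of that benchmark follows; the 4x bound relative to java.util.HashMap comes from the checked-in test, while the collection size, the single timing run, and the QuickCheck name are illustrative only (the real test does several warm-up rounds before asserting):

import scala.collection.JavaConverters._

object QuickCheck extends App {
  val total  = 50000
  val source = scala.collection.mutable.HashMap((0 until total).map(i => (i, i)): _*)

  def time(body: => Unit): Long = {
    val start = System.nanoTime
    body
    System.nanoTime - start
  }

  // copying out of an existing hash table is exactly the case the seed rotation targets
  val scalaTime = time { scala.collection.mutable.HashMap[Int, Int]() ++= source }
  val javaTime  = time {
    val m = new java.util.HashMap[Int, Int]
    m.putAll(source.asJava)
  }

  println("scala: " + scalaTime / 1000000 + " ms, java: " + javaTime / 1000000 + " ms")
  // t5293-map.scala asserts scalamap < javamap * 4 (and likewise for ParHashMap) after warm-up
}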