Diffstat (limited to 'src')
-rw-r--r--  src/library/scala/collection/mutable/FlatHashTable.scala        | 99
-rw-r--r--  src/library/scala/collection/parallel/mutable/ParHashSet.scala  | 13
2 files changed, 81 insertions, 31 deletions
diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala
index 0740d97e09..f3fb6738eb 100644
--- a/src/library/scala/collection/mutable/FlatHashTable.scala
+++ b/src/library/scala/collection/mutable/FlatHashTable.scala
@@ -24,7 +24,7 @@ package mutable
trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
import FlatHashTable._
- private final val tableDebug = false
+ private final def tableDebug = false
@transient private[collection] var _loadFactor = defaultLoadFactor
@@ -43,11 +43,19 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/** The array keeping track of number of elements in 32 element blocks.
*/
@transient protected var sizemap: Array[Int] = null
-
+
+ @transient var seedvalue: Int = tableSizeSeed
+
import HashTable.powerOfTwo
+
protected def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
+
private def initialCapacity = capacity(initialSize)
-
+
+ protected def randomSeed = seedGenerator.get.nextInt()
+
+ protected def tableSizeSeed = Integer.bitCount(table.length - 1)
+
/**
* Initializes the collection from the input stream. `f` will be called for each element
* read from the input stream in the order determined by the stream. This is useful for
@@ -57,23 +65,25 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
*/
private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) {
in.defaultReadObject
-
- _loadFactor = in.readInt
+
+ _loadFactor = in.readInt()
assert(_loadFactor > 0)
-
- val size = in.readInt
+
+ val size = in.readInt()
tableSize = 0
assert(size >= 0)
-
+
table = new Array(capacity(sizeForThreshold(size, _loadFactor)))
threshold = newThreshold(_loadFactor, table.size)
-
- val smDefined = in.readBoolean
+
+ seedvalue = in.readInt()
+
+ val smDefined = in.readBoolean()
if (smDefined) sizeMapInit(table.length) else sizemap = null
-
+
var index = 0
while (index < size) {
- val elem = in.readObject.asInstanceOf[A]
+ val elem = in.readObject().asInstanceOf[A]
f(elem)
addEntry(elem)
index += 1
@@ -89,6 +99,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
out.defaultWriteObject
out.writeInt(_loadFactor)
out.writeInt(tableSize)
+ out.writeInt(seedvalue)
out.writeBoolean(isSizeMapDefined)
iterator.foreach(out.writeObject)
}
@@ -125,6 +136,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
if (entry == elem) return false
h = (h + 1) % table.length
entry = table(h)
+ //Statistics.collisions += 1
}
table(h) = elem.asInstanceOf[AnyRef]
tableSize = tableSize + 1
@@ -185,6 +197,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table = new Array[AnyRef](table.length * 2)
tableSize = 0
nnSizeMapReset(table.length)
+ seedvalue = tableSizeSeed
threshold = newThreshold(_loadFactor, table.length)
var i = 0
while (i < oldtable.length) {
@@ -280,10 +293,24 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/* End of size map handling code */
protected final def index(hcode: Int) = {
+ // version 1 (no longer used - did not work with parallel hash tables)
// improve(hcode) & (table.length - 1)
- val improved = improve(hcode)
+
+ // version 2 (allows for parallel hash table construction)
+ val improved = improve(hcode, seedvalue)
val ones = table.length - 1
(improved >>> (32 - java.lang.Integer.bitCount(ones))) & ones
+
+ // version 3 (solves SI-5293 in most cases, but such a case would still arise for parallel hash tables)
+ // val hc = improve(hcode)
+ // val bbp = blockbitpos
+ // val ones = table.length - 1
+ // val needed = Integer.bitCount(ones)
+ // val blockbits = ((hc >>> bbp) & 0x1f) << (needed - 5)
+ // val rest = ((hc >>> (bbp + 5)) << bbp) | (((1 << bbp) - 1) & hc)
+ // val restmask = (1 << (needed - 5)) - 1
+ // val improved = blockbits | (rest & restmask)
+ // improved
}
protected def clearTable() {
@@ -298,6 +325,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table,
tableSize,
threshold,
+ seedvalue,
sizemap
)
@@ -307,6 +335,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table = c.table
tableSize = c.tableSize
threshold = c.threshold
+ seedvalue = c.seedvalue
sizemap = c.sizemap
}
if (alwaysInitSizeMap && sizemap == null) sizeMapInitAndRebuild
@@ -315,21 +344,30 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
}
-
private[collection] object FlatHashTable {
-
+
+ /** Creates a specific seed to improve hashcode of a hash table instance
+ * and ensure that iteration order vulnerabilities are not 'felt' in other
+ * hash tables.
+ *
+ * See SI-5293.
+ */
+ final def seedGenerator = new ThreadLocal[util.Random] {
+ override def initialValue = new util.Random
+ }
+
/** The load factor for the hash table; must be < 500 (0.5)
*/
- private[collection] def defaultLoadFactor: Int = 450
- private[collection] final def loadFactorDenum = 1000
+ def defaultLoadFactor: Int = 450
+ final def loadFactorDenum = 1000
/** The initial size of the hash table.
*/
- private[collection] def initialSize: Int = 16
+ def initialSize: Int = 32
- private[collection] def sizeForThreshold(size: Int, _loadFactor: Int) = (size.toLong * loadFactorDenum / _loadFactor).toInt
+ def sizeForThreshold(size: Int, _loadFactor: Int) = math.max(32, (size.toLong * loadFactorDenum / _loadFactor).toInt)
- private[collection] def newThreshold(_loadFactor: Int, size: Int) = {
+ def newThreshold(_loadFactor: Int, size: Int) = {
val lf = _loadFactor
assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5")
(size.toLong * lf / loadFactorDenum ).toInt
@@ -340,6 +378,7 @@ private[collection] object FlatHashTable {
val table: Array[AnyRef],
val tableSize: Int,
val threshold: Int,
+ val seedvalue: Int,
val sizemap: Array[Int]
)
@@ -352,16 +391,24 @@ private[collection] object FlatHashTable {
if (elem == null) throw new IllegalArgumentException("Flat hash tables cannot contain null elements.")
else elem.hashCode()
- protected final def improve(hcode: Int) = {
- // var h: Int = hcode + ~(hcode << 9)
- // h = h ^ (h >>> 14)
- // h = h + (h << 4)
- // h ^ (h >>> 10)
+ protected final def improve(hcode: Int, seed: Int) = {
+ //var h: Int = hcode + ~(hcode << 9)
+ //h = h ^ (h >>> 14)
+ //h = h + (h << 4)
+ //h ^ (h >>> 10)
+
var i = hcode * 0x9e3775cd
i = java.lang.Integer.reverseBytes(i)
- i * 0x9e3775cd
+ val improved = i * 0x9e3775cd
+
+ // for the remainder, see SI-5293
+ // to ensure that different bits are used for different hash tables, we have to rotate based on the seed
+ val rotation = seed % 32
+ val rotated = (improved >>> rotation) | (improved << (32 - rotation))
+ rotated
}
}
}
+
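The core of the FlatHashTable change is the seeded improve/index pair ("version 2" above). The following standalone Scala sketch is not part of the patch -- the object name, the demo method, and the sample seeds are illustrative only -- but it mirrors the same arithmetic so the effect of a per-table seedvalue can be tried in isolation.

object SeededHashSketch {

  // Mix the hash code (multiply, byte-reverse, multiply again), then rotate it
  // by a per-table seed so that two tables holding the same elements need not
  // share an iteration order (the SI-5293 scenario).
  def improve(hcode: Int, seed: Int): Int = {
    val mixed = java.lang.Integer.reverseBytes(hcode * 0x9e3775cd) * 0x9e3775cd
    val rotation = seed % 32
    (mixed >>> rotation) | (mixed << (32 - rotation))
  }

  // "Version 2" indexing: take the top bits of the improved hash, masked to a
  // power-of-two table length, instead of the low bits.
  def index(hcode: Int, seed: Int, tableLength: Int): Int = {
    val ones = tableLength - 1
    (improve(hcode, seed) >>> (32 - java.lang.Integer.bitCount(ones))) & ones
  }

  // Two different seeds spread the same elements differently, which is the
  // point of carrying a per-instance seedvalue.
  def demo(): Unit = {
    val elems = Seq("a", "b", "c", "d")
    println(elems.map(e => index(e.hashCode, seed = 5, tableLength = 32)))
    println(elems.map(e => index(e.hashCode, seed = 27, tableLength = 32)))
  }
}
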
diff --git a/src/library/scala/collection/parallel/mutable/ParHashSet.scala b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
index 9dbc7dc6c4..7763cdf318 100644
--- a/src/library/scala/collection/parallel/mutable/ParHashSet.scala
+++ b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
@@ -119,10 +119,11 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
import collection.parallel.tasksupport._
private var mask = ParHashSetCombiner.discriminantmask
private var nonmasklen = ParHashSetCombiner.nonmasklength
-
+ private var seedvalue = 27
+
def +=(elem: T) = {
sz += 1
- val hc = improve(elemHashCode(elem))
+ val hc = improve(elemHashCode(elem), seedvalue)
val pos = hc >>> nonmasklen
if (buckets(pos) eq null) {
// initialize bucket
@@ -140,7 +141,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
private def parPopulate: FlatHashTable.Contents[T] = {
// construct it in parallel
- val table = new AddingFlatHashTable(size, tableLoadFactor)
+ val table = new AddingFlatHashTable(size, tableLoadFactor, seedvalue)
val (inserted, leftovers) = executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
var leftinserts = 0
for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
@@ -153,6 +154,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
// TODO parallelize by keeping separate size maps and merging them
val tbl = new FlatHashTable[T] {
sizeMapInit(table.length)
+ seedvalue = ParHashSetCombiner.this.seedvalue
}
for {
buffer <- buckets;
@@ -168,13 +170,13 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
* it has to take and allocates the underlying hash table in advance.
* Elements can only be added to it. The final size has to be adjusted manually.
* It is internal to `ParHashSet` combiners.
- *
*/
- class AddingFlatHashTable(numelems: Int, lf: Int) extends FlatHashTable[T] {
+ class AddingFlatHashTable(numelems: Int, lf: Int, inseedvalue: Int) extends FlatHashTable[T] {
_loadFactor = lf
table = new Array[AnyRef](capacity(FlatHashTable.sizeForThreshold(numelems, _loadFactor)))
tableSize = 0
threshold = FlatHashTable.newThreshold(_loadFactor, table.length)
+ seedvalue = inseedvalue
sizeMapInit(table.length)
override def toString = "AFHT(%s)".format(table.length)
@@ -310,6 +312,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
}
+
private[parallel] object ParHashSetCombiner {
private[mutable] val discriminantbits = 5
private[mutable] val numblocks = 1 << discriminantbits
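On the parallel side, seedvalue has to be threaded from the combiner into AddingFlatHashTable because bucket selection while combining and slot selection in the finished table must be driven by the same improved hash code. The sketch below is illustrative only: it assumes nonmasklength = 32 - discriminantbits (the hunk above is truncated before that definition) and reuses the same mixing as the sketch after FlatHashTable.scala.

object CombinerSeedSketch {
  // Same mixing as FlatHashTable.improve in the patch.
  private def improve(hcode: Int, seed: Int): Int = {
    val mixed = java.lang.Integer.reverseBytes(hcode * 0x9e3775cd) * 0x9e3775cd
    val rotation = seed % 32
    (mixed >>> rotation) | (mixed << (32 - rotation))
  }

  val discriminantbits = 5                    // as in ParHashSetCombiner
  val nonmasklength    = 32 - discriminantbits // assumed definition

  // While combining, an element goes into one of 32 buckets chosen by the top
  // 5 bits of its improved hash code.
  def bucket(hcode: Int, seed: Int): Int =
    improve(hcode, seed) >>> nonmasklength

  // In the finished table of power-of-two length (at least 32 here), the slot
  // is the top log2(tableLength) bits of the *same* improved hash.
  def slot(hcode: Int, seed: Int, tableLength: Int): Int = {
    val ones = tableLength - 1
    (improve(hcode, seed) >>> (32 - java.lang.Integer.bitCount(ones))) & ones
  }
}

Because the slot's leading bits coincide with the bucket number when both use the same seed, each bucket maps onto one contiguous block of the final table, which is what lets FillBlocks populate the blocks concurrently; with mismatched seeds that invariant would break and elements would probe outside their block.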