summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/library/scala/collection/mutable/FlatHashTable.scala99
-rw-r--r--src/library/scala/collection/parallel/mutable/ParHashSet.scala13
-rw-r--r--test/files/jvm/serialization.check8
-rw-r--r--test/files/run/t5293.scala83
4 files changed, 168 insertions, 35 deletions
diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala
index 0740d97e09..f3fb6738eb 100644
--- a/src/library/scala/collection/mutable/FlatHashTable.scala
+++ b/src/library/scala/collection/mutable/FlatHashTable.scala
@@ -24,7 +24,7 @@ package mutable
trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
import FlatHashTable._
- private final val tableDebug = false
+ private final def tableDebug = false
@transient private[collection] var _loadFactor = defaultLoadFactor
@@ -43,11 +43,19 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/** The array keeping track of number of elements in 32 element blocks.
*/
@transient protected var sizemap: Array[Int] = null
-
+
+ @transient var seedvalue: Int = tableSizeSeed
+
import HashTable.powerOfTwo
+
protected def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
+
private def initialCapacity = capacity(initialSize)
-
+
+ protected def randomSeed = seedGenerator.get.nextInt()
+
+ protected def tableSizeSeed = Integer.bitCount(table.length - 1)
+
/**
* Initializes the collection from the input stream. `f` will be called for each element
* read from the input stream in the order determined by the stream. This is useful for
@@ -57,23 +65,25 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
*/
private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) {
in.defaultReadObject
-
- _loadFactor = in.readInt
+
+ _loadFactor = in.readInt()
assert(_loadFactor > 0)
-
- val size = in.readInt
+
+ val size = in.readInt()
tableSize = 0
assert(size >= 0)
-
+
table = new Array(capacity(sizeForThreshold(size, _loadFactor)))
threshold = newThreshold(_loadFactor, table.size)
-
- val smDefined = in.readBoolean
+
+ seedvalue = in.readInt()
+
+ val smDefined = in.readBoolean()
if (smDefined) sizeMapInit(table.length) else sizemap = null
-
+
var index = 0
while (index < size) {
- val elem = in.readObject.asInstanceOf[A]
+ val elem = in.readObject().asInstanceOf[A]
f(elem)
addEntry(elem)
index += 1
@@ -89,6 +99,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
out.defaultWriteObject
out.writeInt(_loadFactor)
out.writeInt(tableSize)
+ out.writeInt(seedvalue)
out.writeBoolean(isSizeMapDefined)
iterator.foreach(out.writeObject)
}
@@ -125,6 +136,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
if (entry == elem) return false
h = (h + 1) % table.length
entry = table(h)
+ //Statistics.collisions += 1
}
table(h) = elem.asInstanceOf[AnyRef]
tableSize = tableSize + 1
@@ -185,6 +197,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table = new Array[AnyRef](table.length * 2)
tableSize = 0
nnSizeMapReset(table.length)
+ seedvalue = tableSizeSeed
threshold = newThreshold(_loadFactor, table.length)
var i = 0
while (i < oldtable.length) {
@@ -280,10 +293,24 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/* End of size map handling code */
protected final def index(hcode: Int) = {
+ // version 1 (no longer used - did not work with parallel hash tables)
// improve(hcode) & (table.length - 1)
- val improved = improve(hcode)
+
+ // version 2 (allows for parallel hash table construction)
+ val improved = improve(hcode, seedvalue)
val ones = table.length - 1
(improved >>> (32 - java.lang.Integer.bitCount(ones))) & ones
+
+ // version 3 (solves SI-5293 in most cases, but such a case would still arise for parallel hash tables)
+ // val hc = improve(hcode)
+ // val bbp = blockbitpos
+ // val ones = table.length - 1
+ // val needed = Integer.bitCount(ones)
+ // val blockbits = ((hc >>> bbp) & 0x1f) << (needed - 5)
+ // val rest = ((hc >>> (bbp + 5)) << bbp) | (((1 << bbp) - 1) & hc)
+ // val restmask = (1 << (needed - 5)) - 1
+ // val improved = blockbits | (rest & restmask)
+ // improved
}
protected def clearTable() {
@@ -298,6 +325,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table,
tableSize,
threshold,
+ seedvalue,
sizemap
)
@@ -307,6 +335,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table = c.table
tableSize = c.tableSize
threshold = c.threshold
+ seedvalue = c.seedvalue
sizemap = c.sizemap
}
if (alwaysInitSizeMap && sizemap == null) sizeMapInitAndRebuild
@@ -315,21 +344,30 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
}
-
private[collection] object FlatHashTable {
-
+
+ /** Creates a specific seed to improve hashcode of a hash table instance
+ * and ensure that iteration order vulnerabilities are not 'felt' in other
+ * hash tables.
+ *
+ * See SI-5293.
+ */
+ final def seedGenerator = new ThreadLocal[util.Random] {
+ override def initialValue = new util.Random
+ }
+
/** The load factor for the hash table; must be < 500 (0.5)
*/
- private[collection] def defaultLoadFactor: Int = 450
- private[collection] final def loadFactorDenum = 1000
+ def defaultLoadFactor: Int = 450
+ final def loadFactorDenum = 1000
/** The initial size of the hash table.
*/
- private[collection] def initialSize: Int = 16
+ def initialSize: Int = 32
- private[collection] def sizeForThreshold(size: Int, _loadFactor: Int) = (size.toLong * loadFactorDenum / _loadFactor).toInt
+ def sizeForThreshold(size: Int, _loadFactor: Int) = math.max(32, (size.toLong * loadFactorDenum / _loadFactor).toInt)
- private[collection] def newThreshold(_loadFactor: Int, size: Int) = {
+ def newThreshold(_loadFactor: Int, size: Int) = {
val lf = _loadFactor
assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5")
(size.toLong * lf / loadFactorDenum ).toInt
@@ -340,6 +378,7 @@ private[collection] object FlatHashTable {
val table: Array[AnyRef],
val tableSize: Int,
val threshold: Int,
+ val seedvalue: Int,
val sizemap: Array[Int]
)
@@ -352,16 +391,24 @@ private[collection] object FlatHashTable {
if (elem == null) throw new IllegalArgumentException("Flat hash tables cannot contain null elements.")
else elem.hashCode()
- protected final def improve(hcode: Int) = {
- // var h: Int = hcode + ~(hcode << 9)
- // h = h ^ (h >>> 14)
- // h = h + (h << 4)
- // h ^ (h >>> 10)
+ protected final def improve(hcode: Int, seed: Int) = {
+ //var h: Int = hcode + ~(hcode << 9)
+ //h = h ^ (h >>> 14)
+ //h = h + (h << 4)
+ //h ^ (h >>> 10)
+
var i = hcode * 0x9e3775cd
i = java.lang.Integer.reverseBytes(i)
- i * 0x9e3775cd
+ val improved = i * 0x9e3775cd
+
+ // for the remainder, see SI-5293
+ // to ensure that different bits are used for different hash tables, we have to rotate based on the seed
+ val rotation = seed % 32
+ val rotated = (improved >>> rotation) | (improved << (32 - rotation))
+ rotated
}
}
}
+
diff --git a/src/library/scala/collection/parallel/mutable/ParHashSet.scala b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
index 9dbc7dc6c4..7763cdf318 100644
--- a/src/library/scala/collection/parallel/mutable/ParHashSet.scala
+++ b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
@@ -119,10 +119,11 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
import collection.parallel.tasksupport._
private var mask = ParHashSetCombiner.discriminantmask
private var nonmasklen = ParHashSetCombiner.nonmasklength
-
+ private var seedvalue = 27
+
def +=(elem: T) = {
sz += 1
- val hc = improve(elemHashCode(elem))
+ val hc = improve(elemHashCode(elem), seedvalue)
val pos = hc >>> nonmasklen
if (buckets(pos) eq null) {
// initialize bucket
@@ -140,7 +141,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
private def parPopulate: FlatHashTable.Contents[T] = {
// construct it in parallel
- val table = new AddingFlatHashTable(size, tableLoadFactor)
+ val table = new AddingFlatHashTable(size, tableLoadFactor, seedvalue)
val (inserted, leftovers) = executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
var leftinserts = 0
for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
@@ -153,6 +154,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
// TODO parallelize by keeping separate size maps and merging them
val tbl = new FlatHashTable[T] {
sizeMapInit(table.length)
+ seedvalue = ParHashSetCombiner.this.seedvalue
}
for {
buffer <- buckets;
@@ -168,13 +170,13 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
* it has to take and allocates the underlying hash table in advance.
* Elements can only be added to it. The final size has to be adjusted manually.
* It is internal to `ParHashSet` combiners.
- *
*/
- class AddingFlatHashTable(numelems: Int, lf: Int) extends FlatHashTable[T] {
+ class AddingFlatHashTable(numelems: Int, lf: Int, inseedvalue: Int) extends FlatHashTable[T] {
_loadFactor = lf
table = new Array[AnyRef](capacity(FlatHashTable.sizeForThreshold(numelems, _loadFactor)))
tableSize = 0
threshold = FlatHashTable.newThreshold(_loadFactor, table.length)
+ seedvalue = inseedvalue
sizeMapInit(table.length)
override def toString = "AFHT(%s)".format(table.length)
@@ -310,6 +312,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
}
+
private[parallel] object ParHashSetCombiner {
private[mutable] val discriminantbits = 5
private[mutable] val numblocks = 1 << discriminantbits
diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check
index 8704bcc643..15708f0c3b 100644
--- a/test/files/jvm/serialization.check
+++ b/test/files/jvm/serialization.check
@@ -160,8 +160,8 @@ x = Map(C -> 3, B -> 2, A -> 1)
y = Map(C -> 3, A -> 1, B -> 2)
x equals y: true, y equals x: true
-x = Set(layers, title, buffers)
-y = Set(layers, title, buffers)
+x = Set(buffers, title, layers)
+y = Set(buffers, title, layers)
x equals y: true, y equals x: true
x = History()
@@ -279,8 +279,8 @@ x = ParHashMap(1 -> 2, 2 -> 4)
y = ParHashMap(1 -> 2, 2 -> 4)
x equals y: true, y equals x: true
-x = ParHashSet(2, 1, 3)
-y = ParHashSet(2, 1, 3)
+x = ParHashSet(1, 2, 3)
+y = ParHashSet(1, 2, 3)
x equals y: true, y equals x: true
x = ParRange(0, 1, 2, 3, 4)
diff --git a/test/files/run/t5293.scala b/test/files/run/t5293.scala
new file mode 100644
index 0000000000..de1efaec4a
--- /dev/null
+++ b/test/files/run/t5293.scala
@@ -0,0 +1,83 @@
+
+
+
+import scala.collection.JavaConverters._
+
+
+
+object Test extends App {
+
+ def bench(label: String)(body: => Unit): Long = {
+ val start = System.nanoTime
+
+ 0.until(10).foreach(_ => body)
+
+ val end = System.nanoTime
+
+ //println("%s: %s ms".format(label, (end - start) / 1000.0 / 1000.0))
+
+ end - start
+ }
+
+ def benchJava(values: java.util.Collection[Int]) = {
+ bench("Java Set") {
+ val set = new java.util.HashSet[Int]
+
+ set.addAll(values)
+ }
+ }
+
+ def benchScala(values: Iterable[Int]) = {
+ bench("Scala Set") {
+ val set = new scala.collection.mutable.HashSet[Int]
+
+ set ++= values
+ }
+ }
+
+ def benchScalaSorted(values: Iterable[Int]) = {
+ bench("Scala Set sorted") {
+ val set = new scala.collection.mutable.HashSet[Int]
+
+ set ++= values.toArray.sorted
+ }
+ }
+
+ def benchScalaPar(values: Iterable[Int]) = {
+ bench("Scala ParSet") {
+ val set = new scala.collection.parallel.mutable.ParHashSet[Int] map { x => x }
+
+ set ++= values
+ }
+ }
+
+ val values = 0 until 50000
+ val set = scala.collection.mutable.HashSet.empty[Int]
+
+ set ++= values
+
+ // warmup
+ for (x <- 0 until 5) {
+ benchJava(set.asJava)
+ benchScala(set)
+ benchScalaPar(set)
+ benchJava(set.asJava)
+ benchScala(set)
+ benchScalaPar(set)
+ }
+
+ val javaset = benchJava(set.asJava)
+ val scalaset = benchScala(set)
+ val scalaparset = benchScalaPar(set)
+
+ assert(scalaset < (javaset * 4))
+ assert(scalaparset < (javaset * 4))
+}
+
+
+
+
+
+
+
+