summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoraleksandar <aleksandar@lampmac14.epfl.ch>2011-12-19 15:21:59 +0100
committeraleksandar <aleksandar@lampmac14.epfl.ch>2011-12-19 16:26:13 +0100
commit832e3179cb2d0e3dbf1ff63234ec0cbc36a2b2fe (patch)
tree772423693a21a7184715456574b5366149f324f6
parentd1e3b46f5bf58469bffb6f8e2ffebd932b990a5d (diff)
downloadscala-832e3179cb2d0e3dbf1ff63234ec0cbc36a2b2fe.tar.gz
scala-832e3179cb2d0e3dbf1ff63234ec0cbc36a2b2fe.tar.bz2
scala-832e3179cb2d0e3dbf1ff63234ec0cbc36a2b2fe.zip
Fix #5293 - changed the way hashcode is improved in hash sets.
The hash code is further improved by using a special value in each hash set called a `seed`. For sequential hash tables, this value depends on the size of the hash table. It determines the number of bits by which the hashcode should be rotated. This ensures that hash tables with different sizes use different bits to compute the position of an element. This way, traversing the elements of a source hash table yields them in an order in which they had similar hashcodes (and hence, positions) in the source table, but different ones in the destination table. Ideally, in the future we want to be able to have a family of hash functions and assign a different hash function from that family to each hash table instance. That would statistically almost completely eliminate the possibility that the hash table element traversal causes excessive collisions. I should probably @mention extempore here.
-rw-r--r--src/library/scala/collection/mutable/FlatHashTable.scala99
-rw-r--r--src/library/scala/collection/parallel/mutable/ParHashSet.scala13
-rw-r--r--test/files/jvm/serialization.check8
-rw-r--r--test/files/run/t5293.scala83
4 files changed, 168 insertions, 35 deletions
diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala
index 0740d97e09..f3fb6738eb 100644
--- a/src/library/scala/collection/mutable/FlatHashTable.scala
+++ b/src/library/scala/collection/mutable/FlatHashTable.scala
@@ -24,7 +24,7 @@ package mutable
trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
import FlatHashTable._
- private final val tableDebug = false
+ private final def tableDebug = false
@transient private[collection] var _loadFactor = defaultLoadFactor
@@ -43,11 +43,19 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/** The array keeping track of number of elements in 32 element blocks.
*/
@transient protected var sizemap: Array[Int] = null
-
+
+ @transient var seedvalue: Int = tableSizeSeed
+
import HashTable.powerOfTwo
+
protected def capacity(expectedSize: Int) = if (expectedSize == 0) 1 else powerOfTwo(expectedSize)
+
private def initialCapacity = capacity(initialSize)
-
+
+ protected def randomSeed = seedGenerator.get.nextInt()
+
+ protected def tableSizeSeed = Integer.bitCount(table.length - 1)
+
/**
* Initializes the collection from the input stream. `f` will be called for each element
* read from the input stream in the order determined by the stream. This is useful for
@@ -57,23 +65,25 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
*/
private[collection] def init(in: java.io.ObjectInputStream, f: A => Unit) {
in.defaultReadObject
-
- _loadFactor = in.readInt
+
+ _loadFactor = in.readInt()
assert(_loadFactor > 0)
-
- val size = in.readInt
+
+ val size = in.readInt()
tableSize = 0
assert(size >= 0)
-
+
table = new Array(capacity(sizeForThreshold(size, _loadFactor)))
threshold = newThreshold(_loadFactor, table.size)
-
- val smDefined = in.readBoolean
+
+ seedvalue = in.readInt()
+
+ val smDefined = in.readBoolean()
if (smDefined) sizeMapInit(table.length) else sizemap = null
-
+
var index = 0
while (index < size) {
- val elem = in.readObject.asInstanceOf[A]
+ val elem = in.readObject().asInstanceOf[A]
f(elem)
addEntry(elem)
index += 1
@@ -89,6 +99,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
out.defaultWriteObject
out.writeInt(_loadFactor)
out.writeInt(tableSize)
+ out.writeInt(seedvalue)
out.writeBoolean(isSizeMapDefined)
iterator.foreach(out.writeObject)
}
@@ -125,6 +136,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
if (entry == elem) return false
h = (h + 1) % table.length
entry = table(h)
+ //Statistics.collisions += 1
}
table(h) = elem.asInstanceOf[AnyRef]
tableSize = tableSize + 1
@@ -185,6 +197,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table = new Array[AnyRef](table.length * 2)
tableSize = 0
nnSizeMapReset(table.length)
+ seedvalue = tableSizeSeed
threshold = newThreshold(_loadFactor, table.length)
var i = 0
while (i < oldtable.length) {
@@ -280,10 +293,24 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/* End of size map handling code */
protected final def index(hcode: Int) = {
+ // version 1 (no longer used - did not work with parallel hash tables)
// improve(hcode) & (table.length - 1)
- val improved = improve(hcode)
+
+ // version 2 (allows for parallel hash table construction)
+ val improved = improve(hcode, seedvalue)
val ones = table.length - 1
(improved >>> (32 - java.lang.Integer.bitCount(ones))) & ones
+
+ // version 3 (solves SI-5293 in most cases, but such a case would still arise for parallel hash tables)
+ // val hc = improve(hcode)
+ // val bbp = blockbitpos
+ // val ones = table.length - 1
+ // val needed = Integer.bitCount(ones)
+ // val blockbits = ((hc >>> bbp) & 0x1f) << (needed - 5)
+ // val rest = ((hc >>> (bbp + 5)) << bbp) | (((1 << bbp) - 1) & hc)
+ // val restmask = (1 << (needed - 5)) - 1
+ // val improved = blockbits | (rest & restmask)
+ // improved
}
protected def clearTable() {
@@ -298,6 +325,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table,
tableSize,
threshold,
+ seedvalue,
sizemap
)
@@ -307,6 +335,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
table = c.table
tableSize = c.tableSize
threshold = c.threshold
+ seedvalue = c.seedvalue
sizemap = c.sizemap
}
if (alwaysInitSizeMap && sizemap == null) sizeMapInitAndRebuild
@@ -315,21 +344,30 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
}
-
private[collection] object FlatHashTable {
-
+
+ /** Creates a specific seed to improve hashcode of a hash table instance
+ * and ensure that iteration order vulnerabilities are not 'felt' in other
+ * hash tables.
+ *
+ * See SI-5293.
+ */
+ final def seedGenerator = new ThreadLocal[util.Random] {
+ override def initialValue = new util.Random
+ }
+
/** The load factor for the hash table; must be < 500 (0.5)
*/
- private[collection] def defaultLoadFactor: Int = 450
- private[collection] final def loadFactorDenum = 1000
+ def defaultLoadFactor: Int = 450
+ final def loadFactorDenum = 1000
/** The initial size of the hash table.
*/
- private[collection] def initialSize: Int = 16
+ def initialSize: Int = 32
- private[collection] def sizeForThreshold(size: Int, _loadFactor: Int) = (size.toLong * loadFactorDenum / _loadFactor).toInt
+ def sizeForThreshold(size: Int, _loadFactor: Int) = math.max(32, (size.toLong * loadFactorDenum / _loadFactor).toInt)
- private[collection] def newThreshold(_loadFactor: Int, size: Int) = {
+ def newThreshold(_loadFactor: Int, size: Int) = {
val lf = _loadFactor
assert(lf < (loadFactorDenum / 2), "loadFactor too large; must be < 0.5")
(size.toLong * lf / loadFactorDenum ).toInt
@@ -340,6 +378,7 @@ private[collection] object FlatHashTable {
val table: Array[AnyRef],
val tableSize: Int,
val threshold: Int,
+ val seedvalue: Int,
val sizemap: Array[Int]
)
@@ -352,16 +391,24 @@ private[collection] object FlatHashTable {
if (elem == null) throw new IllegalArgumentException("Flat hash tables cannot contain null elements.")
else elem.hashCode()
- protected final def improve(hcode: Int) = {
- // var h: Int = hcode + ~(hcode << 9)
- // h = h ^ (h >>> 14)
- // h = h + (h << 4)
- // h ^ (h >>> 10)
+ protected final def improve(hcode: Int, seed: Int) = {
+ //var h: Int = hcode + ~(hcode << 9)
+ //h = h ^ (h >>> 14)
+ //h = h + (h << 4)
+ //h ^ (h >>> 10)
+
var i = hcode * 0x9e3775cd
i = java.lang.Integer.reverseBytes(i)
- i * 0x9e3775cd
+ val improved = i * 0x9e3775cd
+
+ // for the remainder, see SI-5293
+ // to ensure that different bits are used for different hash tables, we have to rotate based on the seed
+ val rotation = seed % 32
+ val rotated = (improved >>> rotation) | (improved << (32 - rotation))
+ rotated
}
}
}
+
diff --git a/src/library/scala/collection/parallel/mutable/ParHashSet.scala b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
index 9dbc7dc6c4..7763cdf318 100644
--- a/src/library/scala/collection/parallel/mutable/ParHashSet.scala
+++ b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
@@ -119,10 +119,11 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
import collection.parallel.tasksupport._
private var mask = ParHashSetCombiner.discriminantmask
private var nonmasklen = ParHashSetCombiner.nonmasklength
-
+ private var seedvalue = 27
+
def +=(elem: T) = {
sz += 1
- val hc = improve(elemHashCode(elem))
+ val hc = improve(elemHashCode(elem), seedvalue)
val pos = hc >>> nonmasklen
if (buckets(pos) eq null) {
// initialize bucket
@@ -140,7 +141,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
private def parPopulate: FlatHashTable.Contents[T] = {
// construct it in parallel
- val table = new AddingFlatHashTable(size, tableLoadFactor)
+ val table = new AddingFlatHashTable(size, tableLoadFactor, seedvalue)
val (inserted, leftovers) = executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
var leftinserts = 0
for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
@@ -153,6 +154,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
// TODO parallelize by keeping separate size maps and merging them
val tbl = new FlatHashTable[T] {
sizeMapInit(table.length)
+ seedvalue = ParHashSetCombiner.this.seedvalue
}
for {
buffer <- buckets;
@@ -168,13 +170,13 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
* it has to take and allocates the underlying hash table in advance.
* Elements can only be added to it. The final size has to be adjusted manually.
* It is internal to `ParHashSet` combiners.
- *
*/
- class AddingFlatHashTable(numelems: Int, lf: Int) extends FlatHashTable[T] {
+ class AddingFlatHashTable(numelems: Int, lf: Int, inseedvalue: Int) extends FlatHashTable[T] {
_loadFactor = lf
table = new Array[AnyRef](capacity(FlatHashTable.sizeForThreshold(numelems, _loadFactor)))
tableSize = 0
threshold = FlatHashTable.newThreshold(_loadFactor, table.length)
+ seedvalue = inseedvalue
sizeMapInit(table.length)
override def toString = "AFHT(%s)".format(table.length)
@@ -310,6 +312,7 @@ with collection.mutable.FlatHashTable.HashUtils[T] {
}
+
private[parallel] object ParHashSetCombiner {
private[mutable] val discriminantbits = 5
private[mutable] val numblocks = 1 << discriminantbits
diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check
index 8704bcc643..15708f0c3b 100644
--- a/test/files/jvm/serialization.check
+++ b/test/files/jvm/serialization.check
@@ -160,8 +160,8 @@ x = Map(C -> 3, B -> 2, A -> 1)
y = Map(C -> 3, A -> 1, B -> 2)
x equals y: true, y equals x: true
-x = Set(layers, title, buffers)
-y = Set(layers, title, buffers)
+x = Set(buffers, title, layers)
+y = Set(buffers, title, layers)
x equals y: true, y equals x: true
x = History()
@@ -279,8 +279,8 @@ x = ParHashMap(1 -> 2, 2 -> 4)
y = ParHashMap(1 -> 2, 2 -> 4)
x equals y: true, y equals x: true
-x = ParHashSet(2, 1, 3)
-y = ParHashSet(2, 1, 3)
+x = ParHashSet(1, 2, 3)
+y = ParHashSet(1, 2, 3)
x equals y: true, y equals x: true
x = ParRange(0, 1, 2, 3, 4)
diff --git a/test/files/run/t5293.scala b/test/files/run/t5293.scala
new file mode 100644
index 0000000000..de1efaec4a
--- /dev/null
+++ b/test/files/run/t5293.scala
@@ -0,0 +1,83 @@
+
+
+
+import scala.collection.JavaConverters._
+
+
+
+object Test extends App {
+
+ def bench(label: String)(body: => Unit): Long = {
+ val start = System.nanoTime
+
+ 0.until(10).foreach(_ => body)
+
+ val end = System.nanoTime
+
+ //println("%s: %s ms".format(label, (end - start) / 1000.0 / 1000.0))
+
+ end - start
+ }
+
+ def benchJava(values: java.util.Collection[Int]) = {
+ bench("Java Set") {
+ val set = new java.util.HashSet[Int]
+
+ set.addAll(values)
+ }
+ }
+
+ def benchScala(values: Iterable[Int]) = {
+ bench("Scala Set") {
+ val set = new scala.collection.mutable.HashSet[Int]
+
+ set ++= values
+ }
+ }
+
+ def benchScalaSorted(values: Iterable[Int]) = {
+ bench("Scala Set sorted") {
+ val set = new scala.collection.mutable.HashSet[Int]
+
+ set ++= values.toArray.sorted
+ }
+ }
+
+ def benchScalaPar(values: Iterable[Int]) = {
+ bench("Scala ParSet") {
+ val set = new scala.collection.parallel.mutable.ParHashSet[Int] map { x => x }
+
+ set ++= values
+ }
+ }
+
+ val values = 0 until 50000
+ val set = scala.collection.mutable.HashSet.empty[Int]
+
+ set ++= values
+
+ // warmup
+ for (x <- 0 until 5) {
+ benchJava(set.asJava)
+ benchScala(set)
+ benchScalaPar(set)
+ benchJava(set.asJava)
+ benchScala(set)
+ benchScalaPar(set)
+ }
+
+ val javaset = benchJava(set.asJava)
+ val scalaset = benchScala(set)
+ val scalaparset = benchScalaPar(set)
+
+ assert(scalaset < (javaset * 4))
+ assert(scalaparset < (javaset * 4))
+}
+
+
+
+
+
+
+
+