summary | refs | log | tree | commit | diff
path: root/src/library
diff options
context:
space:
mode:
author: James Iry <james.iry@typesafe.com> 2013-01-03 15:55:18 -0800
committer: James Iry <james.iry@typesafe.com> 2013-01-03 15:55:18 -0800
commit: 016763cc3a3045cdc0ae21007f50ae52c3c7d642 (patch)
tree: c4aeb736fdd6dfc7a0361041d653732fc72602ab /src/library
parent: 666572261c41cc92b06c03bf4aa260c198240cd8 (diff)
download: scala-016763cc3a3045cdc0ae21007f50ae52c3c7d642.tar.gz
scala-016763cc3a3045cdc0ae21007f50ae52c3c7d642.tar.bz2
scala-016763cc3a3045cdc0ae21007f50ae52c3c7d642.zip
SI-6908 Makes FlatHashTable as well as derived classes support null values
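This change adds a null sentinel object (NullSentinel) that is stored in the table in place of a null element, so that a null value inserted into a FlatHashTable can be told apart from an empty slot. It also draws a strict distinction between the logical elements of the set (type A) and the entries stored in the hash table (AnyRef), converting between the two with elemToEntry and entryToElem. mutable.HashSet and ParHashSet are updated accordingly.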
Diffstat (limited to 'src/library')
-rw-r--r-- src/library/scala/collection/mutable/FlatHashTable.scala | 100
-rw-r--r-- src/library/scala/collection/mutable/HashSet.scala | 14
-rw-r--r-- src/library/scala/collection/parallel/mutable/ParHashSet.scala | 61
3 files changed, 101 insertions, 74 deletions
diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala
index 7ab99fcda2..0f961b20d3 100644
--- a/src/library/scala/collection/mutable/FlatHashTable.scala
+++ b/src/library/scala/collection/mutable/FlatHashTable.scala
@@ -87,9 +87,9 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
var index = 0
while (index < size) {
- val elem = in.readObject().asInstanceOf[A]
+ val elem = entryToElem(in.readObject())
f(elem)
- addEntry(elem)
+ addElem(elem)
index += 1
}
}
@@ -110,60 +110,72 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
/** Finds an entry in the hash table if such an element exists. */
protected def findEntry(elem: A): Option[A] = {
- val entry = findEntryImpl(elem)
- if (null == entry) None else Some(entry.asInstanceOf[A])
+ val entry = findElemImpl(elem)
+ if (null == entry) None else Some(entryToElem(entry))
}
/** Checks whether an element is contained in the hash table. */
- protected def containsEntry(elem: A): Boolean = {
- null != findEntryImpl(elem)
+ protected def containsElem(elem: A): Boolean = {
+ null != findElemImpl(elem)
}
- private def findEntryImpl(elem: A): AnyRef = {
- var h = index(elemHashCode(elem))
- var entry = table(h)
- while (null != entry && entry != elem) {
+ private def findElemImpl(elem: A): AnyRef = {
+ var searchEntry = elemToEntry(elem)
+ var h = index(searchEntry.hashCode)
+ var curEntry = table(h)
+ while (null != curEntry && curEntry != searchEntry) {
h = (h + 1) % table.length
- entry = table(h)
+ curEntry = table(h)
}
- entry
+ curEntry
}
- /** Add entry if not yet in table.
- * @return Returns `true` if a new entry was added, `false` otherwise.
+ /** Add elem if not yet in table.
+ * @return Returns `true` if a new elem was added, `false` otherwise.
*/
- protected def addEntry(elem: A) : Boolean = {
- var h = index(elemHashCode(elem))
- var entry = table(h)
- while (null != entry) {
- if (entry == elem) return false
+ protected def addElem(elem: A) : Boolean = {
+ addEntry(elemToEntry(elem))
+ }
+
+ /**
+ * Add an entry (an elem converted to an entry via elemToEntry) if not yet in
+ * table.
+ * @return Returns `true` if a new elem was added, `false` otherwise.
+ */
+ protected def addEntry(newEntry : AnyRef) : Boolean = {
+ var h = index(newEntry.hashCode)
+ var curEntry = table(h)
+ while (null != curEntry) {
+ if (curEntry == newEntry) return false
h = (h + 1) % table.length
- entry = table(h)
+ curEntry = table(h)
//Statistics.collisions += 1
}
- table(h) = elem.asInstanceOf[AnyRef]
+ table(h) = newEntry
tableSize = tableSize + 1
nnSizeMapAdd(h)
if (tableSize >= threshold) growTable()
true
+
}
- /** Removes an entry from the hash table, returning an option value with the element, or `None` if it didn't exist. */
- protected def removeEntry(elem: A) : Option[A] = {
+ /** Removes an elem from the hash table, returning an option value with the element, or `None` if it didn't exist. */
+ protected def removeElem(elem: A) : Option[A] = {
if (tableDebug) checkConsistent()
def precedes(i: Int, j: Int) = {
val d = table.length >> 1
if (i <= j) j - i < d
else i - j > d
}
- var h = index(elemHashCode(elem))
- var entry = table(h)
- while (null != entry) {
- if (entry == elem) {
+ val removalEntry = elemToEntry(elem)
+ var h = index(removalEntry.hashCode)
+ var curEntry = table(h)
+ while (null != curEntry) {
+ if (curEntry == removalEntry) {
var h0 = h
var h1 = (h0 + 1) % table.length
while (null != table(h1)) {
- val h2 = index(elemHashCode(table(h1).asInstanceOf[A]))
+ val h2 = index(table(h1).hashCode)
//Console.println("shift at "+h1+":"+table(h1)+" with h2 = "+h2+"? "+(h2 != h1)+precedes(h2, h0)+table.length)
if (h2 != h1 && precedes(h2, h0)) {
//Console.println("shift "+h1+" to "+h0+"!")
@@ -176,10 +188,10 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
tableSize -= 1
nnSizeMapRemove(h0)
if (tableDebug) checkConsistent()
- return Some(entry.asInstanceOf[A])
+ return Some(entryToElem(curEntry))
}
h = (h + 1) % table.length
- entry = table(h)
+ curEntry = table(h)
}
None
}
@@ -191,7 +203,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
i < table.length
}
def next(): A =
- if (hasNext) { i += 1; table(i - 1).asInstanceOf[A] }
+ if (hasNext) { i += 1; entryToElem(table(i - 1)) }
else Iterator.empty.next
}
@@ -205,7 +217,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
var i = 0
while (i < oldtable.length) {
val entry = oldtable(i)
- if (null != entry) addEntry(entry.asInstanceOf[A])
+ if (null != entry) addEntry(entry)
i += 1
}
if (tableDebug) checkConsistent()
@@ -213,9 +225,10 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] {
private def checkConsistent() {
for (i <- 0 until table.length)
- if (table(i) != null && !containsEntry(table(i).asInstanceOf[A]))
+ if (table(i) != null && !containsElem(entryToElem(table(i))))
assert(false, i+" "+table(i)+" "+table.mkString)
}
+
/* Size map handling code */
@@ -358,6 +371,10 @@ private[collection] object FlatHashTable {
final def seedGenerator = new ThreadLocal[scala.util.Random] {
override def initialValue = new scala.util.Random
}
+
+ val NullSentinel = new AnyRef {
+ override def hashCode = 0
+ }
/** The load factor for the hash table; must be < 500 (0.5)
*/
@@ -386,10 +403,6 @@ private[collection] object FlatHashTable {
// so that:
protected final def sizeMapBucketSize = 1 << sizeMapBucketBitSize
- protected def elemHashCode(elem: A) =
- if (elem == null) throw new IllegalArgumentException("Flat hash tables cannot contain null elements.")
- else elem.hashCode()
-
protected final def improve(hcode: Int, seed: Int) = {
//var h: Int = hcode + ~(hcode << 9)
//h = h ^ (h >>> 14)
@@ -404,6 +417,19 @@ private[collection] object FlatHashTable {
val rotated = (improved >>> rotation) | (improved << (32 - rotation))
rotated
}
+
+ /**
+ * Elems have type A, but we store AnyRef in the table. Plus we need to deal with
+ * null elems, which need to be stored as NullSentinel
+ */
+ protected final def elemToEntry(elem : A) : AnyRef =
+ if (null == elem) NullSentinel else elem.asInstanceOf[AnyRef]
+
+ /**
+ * Does the inverse translation of elemToEntry
+ */
+ protected final def entryToElem(entry : AnyRef) : A =
+ (if (NullSentinel eq entry) null else entry).asInstanceOf[A]
}
}
diff --git a/src/library/scala/collection/mutable/HashSet.scala b/src/library/scala/collection/mutable/HashSet.scala
index 74f2a6c762..d2e4bd78b3 100644
--- a/src/library/scala/collection/mutable/HashSet.scala
+++ b/src/library/scala/collection/mutable/HashSet.scala
@@ -55,17 +55,17 @@ extends AbstractSet[A]
override def size: Int = tableSize
- def contains(elem: A): Boolean = containsEntry(elem)
+ def contains(elem: A): Boolean = containsElem(elem)
- def += (elem: A): this.type = { addEntry(elem); this }
+ def += (elem: A): this.type = { addElem(elem); this }
- def -= (elem: A): this.type = { removeEntry(elem); this }
+ def -= (elem: A): this.type = { removeElem(elem); this }
override def par = new ParHashSet(hashTableContents)
- override def add(elem: A): Boolean = addEntry(elem)
+ override def add(elem: A): Boolean = addElem(elem)
- override def remove(elem: A): Boolean = removeEntry(elem).isDefined
+ override def remove(elem: A): Boolean = removeElem(elem).isDefined
override def clear() { clearTable() }
@@ -75,8 +75,8 @@ extends AbstractSet[A]
var i = 0
val len = table.length
while (i < len) {
- val elem = table(i)
- if (elem ne null) f(elem.asInstanceOf[A])
+ val curEntry = table(i)
+ if (curEntry ne null) f(entryToElem(curEntry))
i += 1
}
}
diff --git a/src/library/scala/collection/parallel/mutable/ParHashSet.scala b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
index f1377fb0a7..68a5f20cb5 100644
--- a/src/library/scala/collection/parallel/mutable/ParHashSet.scala
+++ b/src/library/scala/collection/parallel/mutable/ParHashSet.scala
@@ -60,18 +60,18 @@ extends ParSet[T]
override def seq = new scala.collection.mutable.HashSet(hashTableContents)
def +=(elem: T) = {
- addEntry(elem)
+ addElem(elem)
this
}
def -=(elem: T) = {
- removeEntry(elem)
+ removeElem(elem)
this
}
override def stringPrefix = "ParHashSet"
- def contains(elem: T) = containsEntry(elem)
+ def contains(elem: T) = containsElem(elem)
def splitter = new ParHashSetIterator(0, table.length, size)
@@ -117,22 +117,23 @@ object ParHashSet extends ParSetFactory[ParHashSet] {
private[mutable] abstract class ParHashSetCombiner[T](private val tableLoadFactor: Int)
-extends scala.collection.parallel.BucketCombiner[T, ParHashSet[T], Any, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks)
+extends scala.collection.parallel.BucketCombiner[T, ParHashSet[T], AnyRef, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks)
with scala.collection.mutable.FlatHashTable.HashUtils[T] {
//self: EnvironmentPassingCombiner[T, ParHashSet[T]] =>
private val nonmasklen = ParHashSetCombiner.nonmasklength
private val seedvalue = 27
def +=(elem: T) = {
+ val entry = elemToEntry(elem)
sz += 1
- val hc = improve(elemHashCode(elem), seedvalue)
+ val hc = improve(entry.hashCode, seedvalue)
val pos = hc >>> nonmasklen
if (buckets(pos) eq null) {
// initialize bucket
- buckets(pos) = new UnrolledBuffer[Any]
+ buckets(pos) = new UnrolledBuffer[AnyRef]
}
// add to bucket
- buckets(pos) += elem
+ buckets(pos) += entry
this
}
@@ -146,7 +147,7 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
val table = new AddingFlatHashTable(size, tableLoadFactor, seedvalue)
val (inserted, leftovers) = combinerTaskSupport.executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length))
var leftinserts = 0
- for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T])
+ for (entry <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, entry)
table.setSize(leftinserts + inserted)
table.hashTableContents
}
@@ -160,8 +161,8 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
for {
buffer <- buckets;
if buffer ne null;
- elem <- buffer
- } addEntry(elem.asInstanceOf[T])
+ entry <- buffer
+ } addEntry(entry)
}
tbl.hashTableContents
}
@@ -188,7 +189,7 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
def setSize(sz: Int) = tableSize = sz
/**
- * The elements are added using the `insertEntry` method. This method accepts three
+ * The entries are added using the `insertEntry` method. This method accepts three
* arguments:
*
* @param insertAt where to add the element (set to -1 to use its hashcode)
@@ -205,17 +206,17 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
* If the element is already present in the hash table, it is not added, and this method
* returns 0. If the element is added, it returns 1.
*/
- def insertEntry(insertAt: Int, comesBefore: Int, elem: T): Int = {
+ def insertEntry(insertAt: Int, comesBefore: Int, newEntry : AnyRef): Int = {
var h = insertAt
- if (h == -1) h = index(elemHashCode(elem))
- var entry = table(h)
- while (null != entry) {
- if (entry == elem) return 0
+ if (h == -1) h = index(newEntry.hashCode)
+ var curEntry = table(h)
+ while (null != curEntry) {
+ if (curEntry == newEntry) return 0
h = h + 1 // we *do not* do `(h + 1) % table.length` here, because we'll never overflow!!
if (h >= comesBefore) return -1
- entry = table(h)
+ curEntry = table(h)
}
- table(h) = elem.asInstanceOf[AnyRef]
+ table(h) = newEntry
// this is incorrect since we set size afterwards anyway and a counter
// like this would not even work:
@@ -232,13 +233,13 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
/* tasks */
- class FillBlocks(buckets: Array[UnrolledBuffer[Any]], table: AddingFlatHashTable, val offset: Int, val howmany: Int)
- extends Task[(Int, UnrolledBuffer[Any]), FillBlocks] {
- var result = (Int.MinValue, new UnrolledBuffer[Any]);
- def leaf(prev: Option[(Int, UnrolledBuffer[Any])]) {
+ class FillBlocks(buckets: Array[UnrolledBuffer[AnyRef]], table: AddingFlatHashTable, val offset: Int, val howmany: Int)
+ extends Task[(Int, UnrolledBuffer[AnyRef]), FillBlocks] {
+ var result = (Int.MinValue, new UnrolledBuffer[AnyRef]);
+ def leaf(prev: Option[(Int, UnrolledBuffer[AnyRef])]) {
var i = offset
var totalinserts = 0
- var leftover = new UnrolledBuffer[Any]()
+ var leftover = new UnrolledBuffer[AnyRef]()
while (i < (offset + howmany)) {
val (inserted, intonextblock) = fillBlock(i, buckets(i), leftover)
totalinserts += inserted
@@ -250,11 +251,11 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
private val blocksize = table.tableLength >> ParHashSetCombiner.discriminantbits
private def blockStart(block: Int) = block * blocksize
private def nextBlockStart(block: Int) = (block + 1) * blocksize
- private def fillBlock(block: Int, elems: UnrolledBuffer[Any], leftovers: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
+ private def fillBlock(block: Int, elems: UnrolledBuffer[AnyRef], leftovers: UnrolledBuffer[AnyRef]): (Int, UnrolledBuffer[AnyRef]) = {
val beforePos = nextBlockStart(block)
// store the elems
- val (elemsIn, elemsLeft) = if (elems != null) insertAll(-1, beforePos, elems) else (0, UnrolledBuffer[Any]())
+ val (elemsIn, elemsLeft) = if (elems != null) insertAll(-1, beforePos, elems) else (0, UnrolledBuffer[AnyRef]())
// store the leftovers
val (leftoversIn, leftoversLeft) = insertAll(blockStart(block), beforePos, leftovers)
@@ -262,8 +263,8 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
// return the no. of stored elements tupled with leftovers
(elemsIn + leftoversIn, elemsLeft concat leftoversLeft)
}
- private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = {
- val leftovers = new UnrolledBuffer[Any]
+ private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[AnyRef]): (Int, UnrolledBuffer[AnyRef]) = {
+ val leftovers = new UnrolledBuffer[AnyRef]
var inserted = 0
var unrolled = elems.headPtr
@@ -273,10 +274,10 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] {
val chunkarr = unrolled.array
val chunksz = unrolled.size
while (i < chunksz) {
- val elem = chunkarr(i)
- val res = t.insertEntry(atPos, beforePos, elem.asInstanceOf[T])
+ val entry = chunkarr(i)
+ val res = t.insertEntry(atPos, beforePos, entry)
if (res >= 0) inserted += res
- else leftovers += elem
+ else leftovers += entry
i += 1
}
i = 0