From 016763cc3a3045cdc0ae21007f50ae52c3c7d642 Mon Sep 17 00:00:00 2001 From: James Iry Date: Thu, 3 Jan 2013 15:55:18 -0800 Subject: SI-6908 Makes FlatHashTable as well as derived classes support null values This change adds a null sentinel object which is used to indicate that a null value has been inserted in FlatHashTable. It also makes a strong distinction between logical elements of the Set vs entries in the hash table. Changes are made to mutable.HashSet and ParHashSet accordingly. --- .../scala/collection/mutable/FlatHashTable.scala | 100 +++++++++++++-------- src/library/scala/collection/mutable/HashSet.scala | 14 +-- .../collection/parallel/mutable/ParHashSet.scala | 61 ++++++------- 3 files changed, 101 insertions(+), 74 deletions(-) (limited to 'src/library') diff --git a/src/library/scala/collection/mutable/FlatHashTable.scala b/src/library/scala/collection/mutable/FlatHashTable.scala index 7ab99fcda2..0f961b20d3 100644 --- a/src/library/scala/collection/mutable/FlatHashTable.scala +++ b/src/library/scala/collection/mutable/FlatHashTable.scala @@ -87,9 +87,9 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] { var index = 0 while (index < size) { - val elem = in.readObject().asInstanceOf[A] + val elem = entryToElem(in.readObject()) f(elem) - addEntry(elem) + addElem(elem) index += 1 } } @@ -110,60 +110,72 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] { /** Finds an entry in the hash table if such an element exists. */ protected def findEntry(elem: A): Option[A] = { - val entry = findEntryImpl(elem) - if (null == entry) None else Some(entry.asInstanceOf[A]) + val entry = findElemImpl(elem) + if (null == entry) None else Some(entryToElem(entry)) } /** Checks whether an element is contained in the hash table. */ - protected def containsEntry(elem: A): Boolean = { - null != findEntryImpl(elem) + protected def containsElem(elem: A): Boolean = { + null != findElemImpl(elem) } - private def findEntryImpl(elem: A): AnyRef = { - var h = index(elemHashCode(elem)) - var entry = table(h) - while (null != entry && entry != elem) { + private def findElemImpl(elem: A): AnyRef = { + var searchEntry = elemToEntry(elem) + var h = index(searchEntry.hashCode) + var curEntry = table(h) + while (null != curEntry && curEntry != searchEntry) { h = (h + 1) % table.length - entry = table(h) + curEntry = table(h) } - entry + curEntry } - /** Add entry if not yet in table. - * @return Returns `true` if a new entry was added, `false` otherwise. + /** Add elem if not yet in table. + * @return Returns `true` if a new elem was added, `false` otherwise. */ - protected def addEntry(elem: A) : Boolean = { - var h = index(elemHashCode(elem)) - var entry = table(h) - while (null != entry) { - if (entry == elem) return false + protected def addElem(elem: A) : Boolean = { + addEntry(elemToEntry(elem)) + } + + /** + * Add an entry (an elem converted to an entry via elemToEntry) if not yet in + * table. + * @return Returns `true` if a new elem was added, `false` otherwise. + */ + protected def addEntry(newEntry : AnyRef) : Boolean = { + var h = index(newEntry.hashCode) + var curEntry = table(h) + while (null != curEntry) { + if (curEntry == newEntry) return false h = (h + 1) % table.length - entry = table(h) + curEntry = table(h) //Statistics.collisions += 1 } - table(h) = elem.asInstanceOf[AnyRef] + table(h) = newEntry tableSize = tableSize + 1 nnSizeMapAdd(h) if (tableSize >= threshold) growTable() true + } - /** Removes an entry from the hash table, returning an option value with the element, or `None` if it didn't exist. */ - protected def removeEntry(elem: A) : Option[A] = { + /** Removes an elem from the hash table, returning an option value with the element, or `None` if it didn't exist. */ + protected def removeElem(elem: A) : Option[A] = { if (tableDebug) checkConsistent() def precedes(i: Int, j: Int) = { val d = table.length >> 1 if (i <= j) j - i < d else i - j > d } - var h = index(elemHashCode(elem)) - var entry = table(h) - while (null != entry) { - if (entry == elem) { + val removalEntry = elemToEntry(elem) + var h = index(removalEntry.hashCode) + var curEntry = table(h) + while (null != curEntry) { + if (curEntry == removalEntry) { var h0 = h var h1 = (h0 + 1) % table.length while (null != table(h1)) { - val h2 = index(elemHashCode(table(h1).asInstanceOf[A])) + val h2 = index(table(h1).hashCode) //Console.println("shift at "+h1+":"+table(h1)+" with h2 = "+h2+"? "+(h2 != h1)+precedes(h2, h0)+table.length) if (h2 != h1 && precedes(h2, h0)) { //Console.println("shift "+h1+" to "+h0+"!") @@ -176,10 +188,10 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] { tableSize -= 1 nnSizeMapRemove(h0) if (tableDebug) checkConsistent() - return Some(entry.asInstanceOf[A]) + return Some(entryToElem(curEntry)) } h = (h + 1) % table.length - entry = table(h) + curEntry = table(h) } None } @@ -191,7 +203,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] { i < table.length } def next(): A = - if (hasNext) { i += 1; table(i - 1).asInstanceOf[A] } + if (hasNext) { i += 1; entryToElem(table(i - 1)) } else Iterator.empty.next } @@ -205,7 +217,7 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] { var i = 0 while (i < oldtable.length) { val entry = oldtable(i) - if (null != entry) addEntry(entry.asInstanceOf[A]) + if (null != entry) addEntry(entry) i += 1 } if (tableDebug) checkConsistent() @@ -213,9 +225,10 @@ trait FlatHashTable[A] extends FlatHashTable.HashUtils[A] { private def checkConsistent() { for (i <- 0 until table.length) - if (table(i) != null && !containsEntry(table(i).asInstanceOf[A])) + if (table(i) != null && !containsElem(entryToElem(table(i)))) assert(false, i+" "+table(i)+" "+table.mkString) } + /* Size map handling code */ @@ -358,6 +371,10 @@ private[collection] object FlatHashTable { final def seedGenerator = new ThreadLocal[scala.util.Random] { override def initialValue = new scala.util.Random } + + val NullSentinel = new AnyRef { + override def hashCode = 0 + } /** The load factor for the hash table; must be < 500 (0.5) */ @@ -386,10 +403,6 @@ private[collection] object FlatHashTable { // so that: protected final def sizeMapBucketSize = 1 << sizeMapBucketBitSize - protected def elemHashCode(elem: A) = - if (elem == null) throw new IllegalArgumentException("Flat hash tables cannot contain null elements.") - else elem.hashCode() - protected final def improve(hcode: Int, seed: Int) = { //var h: Int = hcode + ~(hcode << 9) //h = h ^ (h >>> 14) @@ -404,6 +417,19 @@ private[collection] object FlatHashTable { val rotated = (improved >>> rotation) | (improved << (32 - rotation)) rotated } + + /** + * Elems have type A, but we store AnyRef in the table. Plus we need to deal with + * null elems, which need to be stored as NullSentinel + */ + protected final def elemToEntry(elem : A) : AnyRef = + if (null == elem) NullSentinel else elem.asInstanceOf[AnyRef] + + /** + * Does the inverse translation of elemToEntry + */ + protected final def entryToElem(entry : AnyRef) : A = + (if (NullSentinel eq entry) null else entry).asInstanceOf[A] } } diff --git a/src/library/scala/collection/mutable/HashSet.scala b/src/library/scala/collection/mutable/HashSet.scala index 74f2a6c762..d2e4bd78b3 100644 --- a/src/library/scala/collection/mutable/HashSet.scala +++ b/src/library/scala/collection/mutable/HashSet.scala @@ -55,17 +55,17 @@ extends AbstractSet[A] override def size: Int = tableSize - def contains(elem: A): Boolean = containsEntry(elem) + def contains(elem: A): Boolean = containsElem(elem) - def += (elem: A): this.type = { addEntry(elem); this } + def += (elem: A): this.type = { addElem(elem); this } - def -= (elem: A): this.type = { removeEntry(elem); this } + def -= (elem: A): this.type = { removeElem(elem); this } override def par = new ParHashSet(hashTableContents) - override def add(elem: A): Boolean = addEntry(elem) + override def add(elem: A): Boolean = addElem(elem) - override def remove(elem: A): Boolean = removeEntry(elem).isDefined + override def remove(elem: A): Boolean = removeElem(elem).isDefined override def clear() { clearTable() } @@ -75,8 +75,8 @@ extends AbstractSet[A] var i = 0 val len = table.length while (i < len) { - val elem = table(i) - if (elem ne null) f(elem.asInstanceOf[A]) + val curEntry = table(i) + if (curEntry ne null) f(entryToElem(curEntry)) i += 1 } } diff --git a/src/library/scala/collection/parallel/mutable/ParHashSet.scala b/src/library/scala/collection/parallel/mutable/ParHashSet.scala index f1377fb0a7..68a5f20cb5 100644 --- a/src/library/scala/collection/parallel/mutable/ParHashSet.scala +++ b/src/library/scala/collection/parallel/mutable/ParHashSet.scala @@ -60,18 +60,18 @@ extends ParSet[T] override def seq = new scala.collection.mutable.HashSet(hashTableContents) def +=(elem: T) = { - addEntry(elem) + addElem(elem) this } def -=(elem: T) = { - removeEntry(elem) + removeElem(elem) this } override def stringPrefix = "ParHashSet" - def contains(elem: T) = containsEntry(elem) + def contains(elem: T) = containsElem(elem) def splitter = new ParHashSetIterator(0, table.length, size) @@ -117,22 +117,23 @@ object ParHashSet extends ParSetFactory[ParHashSet] { private[mutable] abstract class ParHashSetCombiner[T](private val tableLoadFactor: Int) -extends scala.collection.parallel.BucketCombiner[T, ParHashSet[T], Any, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks) +extends scala.collection.parallel.BucketCombiner[T, ParHashSet[T], AnyRef, ParHashSetCombiner[T]](ParHashSetCombiner.numblocks) with scala.collection.mutable.FlatHashTable.HashUtils[T] { //self: EnvironmentPassingCombiner[T, ParHashSet[T]] => private val nonmasklen = ParHashSetCombiner.nonmasklength private val seedvalue = 27 def +=(elem: T) = { + val entry = elemToEntry(elem) sz += 1 - val hc = improve(elemHashCode(elem), seedvalue) + val hc = improve(entry.hashCode, seedvalue) val pos = hc >>> nonmasklen if (buckets(pos) eq null) { // initialize bucket - buckets(pos) = new UnrolledBuffer[Any] + buckets(pos) = new UnrolledBuffer[AnyRef] } // add to bucket - buckets(pos) += elem + buckets(pos) += entry this } @@ -146,7 +147,7 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { val table = new AddingFlatHashTable(size, tableLoadFactor, seedvalue) val (inserted, leftovers) = combinerTaskSupport.executeAndWaitResult(new FillBlocks(buckets, table, 0, buckets.length)) var leftinserts = 0 - for (elem <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, elem.asInstanceOf[T]) + for (entry <- leftovers) leftinserts += table.insertEntry(0, table.tableLength, entry) table.setSize(leftinserts + inserted) table.hashTableContents } @@ -160,8 +161,8 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { for { buffer <- buckets; if buffer ne null; - elem <- buffer - } addEntry(elem.asInstanceOf[T]) + entry <- buffer + } addEntry(entry) } tbl.hashTableContents } @@ -188,7 +189,7 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { def setSize(sz: Int) = tableSize = sz /** - * The elements are added using the `insertEntry` method. This method accepts three + * The elements are added using the `insertElem` method. This method accepts three * arguments: * * @param insertAt where to add the element (set to -1 to use its hashcode) @@ -205,17 +206,17 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { * If the element is already present in the hash table, it is not added, and this method * returns 0. If the element is added, it returns 1. */ - def insertEntry(insertAt: Int, comesBefore: Int, elem: T): Int = { + def insertEntry(insertAt: Int, comesBefore: Int, newEntry : AnyRef): Int = { var h = insertAt - if (h == -1) h = index(elemHashCode(elem)) - var entry = table(h) - while (null != entry) { - if (entry == elem) return 0 + if (h == -1) h = index(newEntry.hashCode) + var curEntry = table(h) + while (null != curEntry) { + if (curEntry == newEntry) return 0 h = h + 1 // we *do not* do `(h + 1) % table.length` here, because we'll never overflow!! if (h >= comesBefore) return -1 - entry = table(h) + curEntry = table(h) } - table(h) = elem.asInstanceOf[AnyRef] + table(h) = newEntry // this is incorrect since we set size afterwards anyway and a counter // like this would not even work: @@ -232,13 +233,13 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { /* tasks */ - class FillBlocks(buckets: Array[UnrolledBuffer[Any]], table: AddingFlatHashTable, val offset: Int, val howmany: Int) - extends Task[(Int, UnrolledBuffer[Any]), FillBlocks] { - var result = (Int.MinValue, new UnrolledBuffer[Any]); - def leaf(prev: Option[(Int, UnrolledBuffer[Any])]) { + class FillBlocks(buckets: Array[UnrolledBuffer[AnyRef]], table: AddingFlatHashTable, val offset: Int, val howmany: Int) + extends Task[(Int, UnrolledBuffer[AnyRef]), FillBlocks] { + var result = (Int.MinValue, new UnrolledBuffer[AnyRef]); + def leaf(prev: Option[(Int, UnrolledBuffer[AnyRef])]) { var i = offset var totalinserts = 0 - var leftover = new UnrolledBuffer[Any]() + var leftover = new UnrolledBuffer[AnyRef]() while (i < (offset + howmany)) { val (inserted, intonextblock) = fillBlock(i, buckets(i), leftover) totalinserts += inserted @@ -250,11 +251,11 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { private val blocksize = table.tableLength >> ParHashSetCombiner.discriminantbits private def blockStart(block: Int) = block * blocksize private def nextBlockStart(block: Int) = (block + 1) * blocksize - private def fillBlock(block: Int, elems: UnrolledBuffer[Any], leftovers: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = { + private def fillBlock(block: Int, elems: UnrolledBuffer[AnyRef], leftovers: UnrolledBuffer[AnyRef]): (Int, UnrolledBuffer[AnyRef]) = { val beforePos = nextBlockStart(block) // store the elems - val (elemsIn, elemsLeft) = if (elems != null) insertAll(-1, beforePos, elems) else (0, UnrolledBuffer[Any]()) + val (elemsIn, elemsLeft) = if (elems != null) insertAll(-1, beforePos, elems) else (0, UnrolledBuffer[AnyRef]()) // store the leftovers val (leftoversIn, leftoversLeft) = insertAll(blockStart(block), beforePos, leftovers) @@ -262,8 +263,8 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { // return the no. of stored elements tupled with leftovers (elemsIn + leftoversIn, elemsLeft concat leftoversLeft) } - private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[Any]): (Int, UnrolledBuffer[Any]) = { - val leftovers = new UnrolledBuffer[Any] + private def insertAll(atPos: Int, beforePos: Int, elems: UnrolledBuffer[AnyRef]): (Int, UnrolledBuffer[AnyRef]) = { + val leftovers = new UnrolledBuffer[AnyRef] var inserted = 0 var unrolled = elems.headPtr @@ -273,10 +274,10 @@ with scala.collection.mutable.FlatHashTable.HashUtils[T] { val chunkarr = unrolled.array val chunksz = unrolled.size while (i < chunksz) { - val elem = chunkarr(i) - val res = t.insertEntry(atPos, beforePos, elem.asInstanceOf[T]) + val entry = chunkarr(i) + val res = t.insertEntry(atPos, beforePos, entry) if (res >= 0) inserted += res - else leftovers += elem + else leftovers += entry i += 1 } i = 0 -- cgit v1.2.3