From c8ddf01621417ef1e5f07682a360ff1b1ebfd416 Mon Sep 17 00:00:00 2001 From: Paul Phillips Date: Sun, 9 Jan 2011 06:30:35 +0000 Subject: Closes #3984 by the most arduous and indirect r... Closes #3984 by the most arduous and indirect route imaginable: abstracts the common code out of the TrieIterators in HashMap and HashSet. When I raised this flag to find out if anyone would open fire, all was quiet on the western front. Although I wouldn't want to write code like this as an everyday thing, I think it serves as a nice showcase for some of the abstraction challenges we're up against: performance looks the same and I will never again have to fix the same bug in two places. Review by rompf. --- .../scala/collection/immutable/HashMap.scala | 178 ++++----------------- 1 file changed, 30 insertions(+), 148 deletions(-) (limited to 'src/library/scala/collection/immutable/HashMap.scala') diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala index a21e71158e..93023a3097 100644 --- a/src/library/scala/collection/immutable/HashMap.scala +++ b/src/library/scala/collection/immutable/HashMap.scala @@ -12,12 +12,9 @@ package scala.collection package immutable import generic._ -import annotation.unchecked.uncheckedVariance - - +import annotation.unchecked.{ uncheckedVariance=> uV } import parallel.immutable.ParHashMap - /** This class implements immutable maps using a hash trie. * * '''Note:''' the builder of a hash map returns specialized representations EmptyMap,Map1,..., Map4 @@ -96,7 +93,6 @@ class HashMap[A, +B] extends Map[A,B] with MapLike[A, B, HashMap[A, B]] with Par private type C = (A, B) override def toParMap[D, E](implicit ev: C <:< (D, E)) = par.asInstanceOf[ParHashMap[D, E]] - } /** $factoryInfo @@ -117,7 +113,7 @@ object HashMap extends ImmutableMapFactory[HashMap] { // TODO: add HashMap2, HashMap3, ... - class HashMap1[A,+B](private[HashMap] var key: A, private[HashMap] var hash: Int, private[collection] var value: (B @uncheckedVariance), private[collection] var kv: (A,B @uncheckedVariance)) extends HashMap[A,B] { + class HashMap1[A,+B](private[HashMap] var key: A, private[HashMap] var hash: Int, private[collection] var value: (B @uV), private[collection] var kv: (A,B @uV)) extends HashMap[A,B] { override def size = 1 private[collection] def getKey = key @@ -188,7 +184,7 @@ object HashMap extends ImmutableMapFactory[HashMap] { } } - private[collection] class HashMapCollision1[A,+B](private[HashMap] var hash: Int, var kvs: ListMap[A,B @uncheckedVariance]) extends HashMap[A,B] { + private[collection] class HashMapCollision1[A,+B](private[HashMap] var hash: Int, var kvs: ListMap[A,B @uV]) extends HashMap[A,B] { override def size = kvs.size override def get0(key: A, hash: Int, level: Int): Option[B] = @@ -219,7 +215,7 @@ object HashMap extends ImmutableMapFactory[HashMap] { override def foreach[U](f: ((A, B)) => U): Unit = kvs.foreach(f) override def split: Seq[HashMap[A, B]] = { val (x, y) = kvs.splitAt(kvs.size / 2) - def newhm(lm: ListMap[A, B @uncheckedVariance]) = new HashMapCollision1(hash, lm) + def newhm(lm: ListMap[A, B @uV]) = new HashMapCollision1(hash, lm) List(newhm(x), newhm(y)) } protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = { @@ -230,7 +226,7 @@ object HashMap extends ImmutableMapFactory[HashMap] { } } - class HashTrieMap[A,+B](private[HashMap] var bitmap: Int, private[collection] var elems: Array[HashMap[A,B @uncheckedVariance]], + class HashTrieMap[A,+B](private[HashMap] var bitmap: Int, private[collection] var elems: Array[HashMap[A,B @uV]], private[HashMap] var size0: Int) extends HashMap[A,B] { /* def this (level: Int, m1: HashMap1[A,B], m2: HashMap1[A,B]) = { @@ -316,7 +312,7 @@ object HashMap extends ImmutableMapFactory[HashMap] { } } - override def iterator = new TrieIterator[A, B](elems) + override def iterator: Iterator[(A, B)] = new CovariantTrieIterator[A, B](elems) /* @@ -468,148 +464,35 @@ time { mNew.iterator.foreach( p => ()) } case hm: HashMap[_, _] => this case _ => system.error("section supposed to be unreachable.") } - } - class TrieIterator[A, +B](elems: Array[HashMap[A, B]]) extends Iterator[(A, B)] { - protected var depth = 0 - protected var arrayStack: Array[Array[HashMap[A, B @uncheckedVariance]]] = new Array[Array[HashMap[A,B]]](6) - protected var posStack = new Array[Int](6) - - protected var arrayD: Array[HashMap[A, B @uncheckedVariance]] = elems - protected var posD = 0 - - protected var subIter: Iterator[(A, B @uncheckedVariance)] = null // to traverse collision nodes - - def dupIterator: TrieIterator[A, B] = { - val t = new TrieIterator(elems) - t.depth = depth - t.arrayStack = arrayStack - t.posStack = posStack - t.arrayD = arrayD - t.posD = posD - t.subIter = subIter - t - } - - def hasNext = (subIter ne null) || depth >= 0 + class CovariantTrieIterator[A, +B](elems: Array[HashMap[A, B]]) extends Iterator[(A, B)] { + private[this] val it = new TrieIterator[A, B](elems) + def next = it.next + def hasNext = it.hasNext + } - def next: (A,B) = { - if (subIter ne null) { - val el = subIter.next - if (!subIter.hasNext) - subIter = null - el - } else - next0(arrayD, posD) - } + class TrieIterator[A, B](elems: Array[HashMap[A, B]]) extends TrieIteratorBase[(A, B), HashMap[A, B]](elems) { + import TrieIteratorBase._ - @scala.annotation.tailrec private[this] def next0(elems: Array[HashMap[A,B]], i: Int): (A,B) = { - if (i == elems.length-1) { // reached end of level, pop stack - depth -= 1 - if (depth >= 0) { - arrayD = arrayStack(depth) - posD = posStack(depth) - arrayStack(depth) = null - } else { - arrayD = null - posD = 0 - } - } else - posD += 1 + type This = TrieIterator[A, B] + private[immutable] def recreateIterator() = new TrieIterator(elems) + private[immutable] type ContainerType = HashMap1[A, B] + private[immutable] type TrieType = HashTrieMap[A, B] + private[immutable] type CollisionType = HashMapCollision1[A, B] - elems(i) match { - case m: HashTrieMap[_, _] => // push current pos onto stack and descend - if (depth >= 0) { - arrayStack(depth) = arrayD - posStack(depth) = posD - } - depth += 1 - val elems = m.elems.asInstanceOf[Array[HashMap[A, B]]] - arrayD = elems - posD = 0 - next0(elems, 0) - case m: HashMap1[_, _] => m.ensurePair - case m => - subIter = m.iterator - subIter.next - } + private[immutable] def determineType(x: HashMap[A, B]) = x match { + case _: HashMap1[_, _] => CONTAINER_TYPE + case _: HashTrieMap[_, _] => TRIE_TYPE + case _: HashMapCollision1[_, _] => COLLISION_TYPE } - // assumption: contains 2 or more elements - // splits this iterator into 2 iterators - // returns the 1st iterator, its number of elements, and the second iterator - def split: ((Iterator[(A, B)], Int), Iterator[(A, B)]) = { - def collisionToArray(c: HashMapCollision1[_, _]) = - c.asInstanceOf[HashMapCollision1[A, B]].kvs.toArray map { HashMap() + _ } - def arrayToIterators(arr: Array[HashMap[A, B]]) = { - val (fst, snd) = arr.splitAt(arr.length / 2) - val szsnd = snd.foldLeft(0)(_ + _.size) - ((new TrieIterator(snd), szsnd), new TrieIterator(fst)) - } - def splitArray(ad: Array[HashMap[A, B]]): ((Iterator[(A, B)], Int), Iterator[(A, B)]) = if (ad.length > 1) { - arrayToIterators(ad) - } else ad(0) match { - case c: HashMapCollision1[a, b] => arrayToIterators(collisionToArray(c.asInstanceOf[HashMapCollision1[A, B]])) - case hm: HashTrieMap[a, b] => splitArray(hm.elems.asInstanceOf[Array[HashMap[A, B]]]) - } - - // 0) simple case: no elements have been iterated - simply divide arrayD - if (arrayD != null && depth == 0 && posD == 0) { - return splitArray(arrayD) - } - - // otherwise, some elements have been iterated over - // 1) collision case: if we have a subIter, we return subIter and elements after it - if (subIter ne null) { - val buff = subIter.toBuffer - subIter = null - ((buff.iterator, buff.length), this) - } else { - // otherwise find the topmost array stack element - if (depth > 0) { - // 2) topmost comes before (is not) arrayD - // steal a portion of top to create a new iterator - val topmost = arrayStack(0) - if (posStack(0) == arrayStack(0).length - 1) { - // 2a) only a single entry left on top - // this means we have to modify this iterator - pop topmost - val snd = Array(arrayStack(0).last) - val szsnd = snd(0).size - // modify this - pop - depth -= 1 - arrayStack = arrayStack.tail ++ Array[Array[HashMap[A, B]]](null) - posStack = posStack.tail ++ Array[Int](0) - // we know that `this` is not empty, since it had something on the arrayStack and arrayStack elements are always non-empty - ((new TrieIterator[A, B](snd), szsnd), this) - } else { - // 2b) more than a single entry left on top - val (fst, snd) = arrayStack(0).splitAt(arrayStack(0).length - (arrayStack(0).length - posStack(0) + 1) / 2) - arrayStack(0) = fst - val szsnd = snd.foldLeft(0)(_ + _.size) - ((new TrieIterator[A, B](snd), szsnd), this) - } - } else { - // 3) no topmost element (arrayD is at the top) - // steal a portion of it and update this iterator - if (posD == arrayD.length - 1) { - // 3a) positioned at the last element of arrayD - val arr: Array[HashMap[A, B]] = arrayD(posD) match { - case c: HashMapCollision1[a, b] => collisionToArray(c).asInstanceOf[Array[HashMap[A, B]]] - case ht: HashTrieMap[_, _] => ht.asInstanceOf[HashTrieMap[A, B]].elems - case _ => system.error("cannot divide single element") - } - arrayToIterators(arr) - } else { - // 3b) arrayD has more free elements - val (fst, snd) = arrayD.splitAt(arrayD.length - (arrayD.length - posD + 1) / 2) - arrayD = fst - val szsnd = snd.foldLeft(0)(_ + _.size) - ((new TrieIterator[A, B](snd), szsnd), this) - } - } - } - } + private[immutable] def getElem(cc: ContainerType) = cc.ensurePair + private[immutable] def getElems(t: TrieType) = t.elems + private[immutable] def collisionToArray(c: CollisionType) = c.kvs map (x => HashMap(x)) toArray + private[immutable] def newThisType(xs: Array[HashMap[A, B]]) = new TrieIterator(xs) + private[immutable] def newDeepArray(size: Int) = new Array[Array[HashMap[A, B]]](size) + private[immutable] def newSingleArray(el: HashMap[A, B]) = Array(el) } private def check[K](x: HashMap[K, _], y: HashMap[K, _], xy: HashMap[K, _]) = { // TODO remove this debugging helper @@ -631,7 +514,8 @@ time { mNew.iterator.foreach( p => ()) } } else true } - @SerialVersionUID(2L) private class SerializationProxy[A,B](@transient private var orig: HashMap[A, B]) extends Serializable { + @SerialVersionUID(2L) + private class SerializationProxy[A,B](@transient private var orig: HashMap[A, B]) extends Serializable { private def writeObject(out: java.io.ObjectOutputStream) { val s = orig.size out.writeInt(s) @@ -653,6 +537,4 @@ time { mNew.iterator.foreach( p => ()) } private def readResolve(): AnyRef = orig } - } - -- cgit v1.2.3