From 90cd1f9ccd272e436b23af70aa1465c673aab25e Mon Sep 17 00:00:00 2001 From: Aleksandar Prokopec Date: Wed, 6 Jun 2012 17:29:02 +0200 Subject: Fix SI-5879. Fix a bug where a key in an immutable hash map have the corresponding value different in the iteration than when doing lookup. This use to happen after calling `merge`. Fix the order in which a key-value pair appears in the collision resolution function - the first argument always comes from the `this` hash map. Deprecate `merge` in favour of `merged`, as this is a pure method. As an added benefit, the syntax for invoking `merge` is now nicer. --- .../scala/collection/immutable/HashMap.scala | 67 +++++++++++++++++----- 1 file changed, 52 insertions(+), 15 deletions(-) (limited to 'src') diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala index 13a0febfee..05529761d3 100644 --- a/src/library/scala/collection/immutable/HashMap.scala +++ b/src/library/scala/collection/immutable/HashMap.scala @@ -74,11 +74,22 @@ class HashMap[A, +B] extends AbstractMap[A, B] private[collection] def computeHash(key: A) = improve(elemHashCode(key)) - protected type Merger[B1] = ((A, B1), (A, B1)) => (A, B1) + protected type MergeFunction[A1, B1] = ((A1, B1), (A1, B1)) => (A1, B1); + + import HashMap.Merger + + protected def liftMerger[A1, B1](mergef: MergeFunction[A1, B1]): Merger[A1, B1] = if (mergef == null) null else new Merger[A1, B1] { + self => + def apply(kv1: (A1, B1), kv2: (A1, B1)): (A1, B1) = mergef(kv1, kv2) + val invert: Merger[A1, B1] = new Merger[A1, B1] { + def apply(kv1: (A1, B1), kv2: (A1, B1)): (A1, B1) = mergef(kv2, kv1) + def invert: Merger[A1, B1] = self + } + } private[collection] def get0(key: A, hash: Int, level: Int): Option[B] = None - private[collection] def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] = + private[collection] def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] = new HashMap.HashMap1(key, hash, value, kv) protected def removed0(key: A, hash: Int, level: Int): HashMap[A, B] = this @@ -87,9 +98,25 @@ class HashMap[A, +B] extends AbstractMap[A, B] def split: Seq[HashMap[A, B]] = Seq(this) - def merge[B1 >: B](that: HashMap[A, B1], merger: Merger[B1] = null): HashMap[A, B1] = merge0(that, 0, merger) - - protected def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = that + @deprecated("Use the `merged` method instead.", "2.10.0") + def merge[B1 >: B](that: HashMap[A, B1], mergef: MergeFunction[A, B1] = null): HashMap[A, B1] = merge0(that, 0, liftMerger(mergef)) + + /** Creates a new map which is the merge of this and the argument hash map. + * + * Uses the specified collision resolution function if two keys are the same. + * The collision resolution function will always take the first argument from + * `this` hash map and the second from `that`. + * + * The `merged` method is on average more performant than doing a traversal and reconstructing a + * new immutable hash map from scratch, or `++`. + * + * @tparam B1 the value type of the other hash map + * @param that the other hash map + * @param mergef the merge function or null if the first key-value pair is to be picked + */ + def merged[B1 >: B](that: HashMap[A, B1])(mergef: MergeFunction[A, B1]): HashMap[A, B1] = merge0(that, 0, liftMerger(mergef)) + + protected def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = that override def par = ParHashMap.fromTrie(this) @@ -103,6 +130,13 @@ class HashMap[A, +B] extends AbstractMap[A, B] * @since 2.3 */ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int { + + private[immutable] abstract class Merger[A, B] { + def apply(kv1: (A, B), kv2: (A, B)): (A, B) + def invert: Merger[A, B] + } + + /** $mapCanBuildFromInfo */ implicit def canBuildFrom[A, B]: CanBuildFrom[Coll, (A, B), HashMap[A, B]] = new MapCanBuildFrom[A, B] def empty[A, B]: HashMap[A, B] = EmptyHashMap.asInstanceOf[HashMap[A, B]] @@ -136,12 +170,15 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int { // } // } - override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] = + override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] = if (hash == this.hash && key == this.key ) { if (merger eq null) { - if(this.value.asInstanceOf[AnyRef] eq value.asInstanceOf[AnyRef]) this + if (this.value.asInstanceOf[AnyRef] eq value.asInstanceOf[AnyRef]) this else new HashMap1(key, hash, value, kv) - } else new HashMap1(key, hash, value, merger(this.kv, kv)) + } else { + val nkv = merger(this.kv, kv) + new HashMap1(nkv._1, hash, nkv._2, nkv) + } } else { var thatindex = (hash >>> level) & 0x1f var thisindex = (this.hash >>> level) & 0x1f @@ -180,8 +217,8 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int { override def foreach[U](f: ((A, B)) => U): Unit = f(ensurePair) // this method may be called multiple times in a multithreaded environment, but that's ok private[HashMap] def ensurePair: (A,B) = if (kv ne null) kv else { kv = (key, value); kv } - protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = { - that.updated0(key, hash, level, value, kv, merger) + protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = { + that.updated0(key, hash, level, value, kv, if (merger ne null) merger.invert else null) } } @@ -193,7 +230,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int { override def get0(key: A, hash: Int, level: Int): Option[B] = if (hash == this.hash) kvs.get(key) else None - override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] = + override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] = if (hash == this.hash) { if ((merger eq null) || !kvs.contains(key)) new HashMapCollision1(hash, kvs.updated(key, value)) else new HashMapCollision1(hash, kvs + merger((key, kvs(key)), kv)) @@ -221,7 +258,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int { def newhm(lm: ListMap[A, B @uV]) = new HashMapCollision1(hash, lm) List(newhm(x), newhm(y)) } - protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = { + protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = { // this can be made more efficient by passing the entire ListMap at once var m = that for (p <- kvs) m = m.updated0(p._1, this.hash, level, p._2, p, merger) @@ -268,7 +305,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int { None } - override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] = { + override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] = { val index = (hash >>> level) & 0x1f val mask = (1 << index) val offset = Integer.bitCount(bitmap & (mask-1)) @@ -380,7 +417,7 @@ time { mNew.iterator.foreach( p => ()) } } else elems(0).split } - protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = that match { + protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = that match { case hm: HashMap1[_, _] => this.updated0(hm.key, hm.hash, level, hm.value.asInstanceOf[B1], hm.kv, merger) case hm: HashTrieMap[_, _] => @@ -440,7 +477,7 @@ time { mNew.iterator.foreach( p => ()) } } new HashTrieMap[A, B1](this.bitmap | that.bitmap, merged, totalelems) - case hm: HashMapCollision1[_, _] => that.merge0(this, level, merger) + case hm: HashMapCollision1[_, _] => that.merge0(this, level, if (merger ne null) merger.invert else null) case hm: HashMap[_, _] => this case _ => sys.error("section supposed to be unreachable.") } -- cgit v1.2.3