summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/library/scala/collection/immutable/HashMap.scala67
-rw-r--r--test/files/run/t5879.check16
-rw-r--r--test/files/run/t5879.scala74
3 files changed, 142 insertions, 15 deletions
diff --git a/src/library/scala/collection/immutable/HashMap.scala b/src/library/scala/collection/immutable/HashMap.scala
index 13a0febfee..05529761d3 100644
--- a/src/library/scala/collection/immutable/HashMap.scala
+++ b/src/library/scala/collection/immutable/HashMap.scala
@@ -74,11 +74,22 @@ class HashMap[A, +B] extends AbstractMap[A, B]
private[collection] def computeHash(key: A) = improve(elemHashCode(key))
- protected type Merger[B1] = ((A, B1), (A, B1)) => (A, B1)
+ protected type MergeFunction[A1, B1] = ((A1, B1), (A1, B1)) => (A1, B1);
+
+ import HashMap.Merger
+
+ protected def liftMerger[A1, B1](mergef: MergeFunction[A1, B1]): Merger[A1, B1] = if (mergef == null) null else new Merger[A1, B1] {
+ self =>
+ def apply(kv1: (A1, B1), kv2: (A1, B1)): (A1, B1) = mergef(kv1, kv2)
+ val invert: Merger[A1, B1] = new Merger[A1, B1] {
+ def apply(kv1: (A1, B1), kv2: (A1, B1)): (A1, B1) = mergef(kv2, kv1)
+ def invert: Merger[A1, B1] = self
+ }
+ }
private[collection] def get0(key: A, hash: Int, level: Int): Option[B] = None
- private[collection] def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] =
+ private[collection] def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] =
new HashMap.HashMap1(key, hash, value, kv)
protected def removed0(key: A, hash: Int, level: Int): HashMap[A, B] = this
@@ -87,9 +98,25 @@ class HashMap[A, +B] extends AbstractMap[A, B]
def split: Seq[HashMap[A, B]] = Seq(this)
- def merge[B1 >: B](that: HashMap[A, B1], merger: Merger[B1] = null): HashMap[A, B1] = merge0(that, 0, merger)
-
- protected def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = that
+ @deprecated("Use the `merged` method instead.", "2.10.0")
+ def merge[B1 >: B](that: HashMap[A, B1], mergef: MergeFunction[A, B1] = null): HashMap[A, B1] = merge0(that, 0, liftMerger(mergef))
+
+ /** Creates a new map which is the merge of this and the argument hash map.
+ *
+ * Uses the specified collision resolution function if two keys are the same.
+ * The collision resolution function will always take the first argument from
+ * `this` hash map and the second from `that`.
+ *
+ * The `merged` method is on average more performant than doing a traversal and reconstructing a
+ * new immutable hash map from scratch, or `++`.
+ *
+ * @tparam B1 the value type of the other hash map
+ * @param that the other hash map
+ * @param mergef the merge function or null if the first key-value pair is to be picked
+ */
+ def merged[B1 >: B](that: HashMap[A, B1])(mergef: MergeFunction[A, B1]): HashMap[A, B1] = merge0(that, 0, liftMerger(mergef))
+
+ protected def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = that
override def par = ParHashMap.fromTrie(this)
@@ -103,6 +130,13 @@ class HashMap[A, +B] extends AbstractMap[A, B]
* @since 2.3
*/
object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
+
+ private[immutable] abstract class Merger[A, B] {
+ def apply(kv1: (A, B), kv2: (A, B)): (A, B)
+ def invert: Merger[A, B]
+ }
+
+
/** $mapCanBuildFromInfo */
implicit def canBuildFrom[A, B]: CanBuildFrom[Coll, (A, B), HashMap[A, B]] = new MapCanBuildFrom[A, B]
def empty[A, B]: HashMap[A, B] = EmptyHashMap.asInstanceOf[HashMap[A, B]]
@@ -136,12 +170,15 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
// }
// }
- override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] =
+ override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] =
if (hash == this.hash && key == this.key ) {
if (merger eq null) {
- if(this.value.asInstanceOf[AnyRef] eq value.asInstanceOf[AnyRef]) this
+ if (this.value.asInstanceOf[AnyRef] eq value.asInstanceOf[AnyRef]) this
else new HashMap1(key, hash, value, kv)
- } else new HashMap1(key, hash, value, merger(this.kv, kv))
+ } else {
+ val nkv = merger(this.kv, kv)
+ new HashMap1(nkv._1, hash, nkv._2, nkv)
+ }
} else {
var thatindex = (hash >>> level) & 0x1f
var thisindex = (this.hash >>> level) & 0x1f
@@ -180,8 +217,8 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
override def foreach[U](f: ((A, B)) => U): Unit = f(ensurePair)
// this method may be called multiple times in a multithreaded environment, but that's ok
private[HashMap] def ensurePair: (A,B) = if (kv ne null) kv else { kv = (key, value); kv }
- protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = {
- that.updated0(key, hash, level, value, kv, merger)
+ protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = {
+ that.updated0(key, hash, level, value, kv, if (merger ne null) merger.invert else null)
}
}
@@ -193,7 +230,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
override def get0(key: A, hash: Int, level: Int): Option[B] =
if (hash == this.hash) kvs.get(key) else None
- override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] =
+ override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] =
if (hash == this.hash) {
if ((merger eq null) || !kvs.contains(key)) new HashMapCollision1(hash, kvs.updated(key, value))
else new HashMapCollision1(hash, kvs + merger((key, kvs(key)), kv))
@@ -221,7 +258,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
def newhm(lm: ListMap[A, B @uV]) = new HashMapCollision1(hash, lm)
List(newhm(x), newhm(y))
}
- protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = {
+ protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = {
// this can be made more efficient by passing the entire ListMap at once
var m = that
for (p <- kvs) m = m.updated0(p._1, this.hash, level, p._2, p, merger)
@@ -268,7 +305,7 @@ object HashMap extends ImmutableMapFactory[HashMap] with BitOperations.Int {
None
}
- override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[B1]): HashMap[A, B1] = {
+ override def updated0[B1 >: B](key: A, hash: Int, level: Int, value: B1, kv: (A, B1), merger: Merger[A, B1]): HashMap[A, B1] = {
val index = (hash >>> level) & 0x1f
val mask = (1 << index)
val offset = Integer.bitCount(bitmap & (mask-1))
@@ -380,7 +417,7 @@ time { mNew.iterator.foreach( p => ()) }
} else elems(0).split
}
- protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[B1]): HashMap[A, B1] = that match {
+ protected override def merge0[B1 >: B](that: HashMap[A, B1], level: Int, merger: Merger[A, B1]): HashMap[A, B1] = that match {
case hm: HashMap1[_, _] =>
this.updated0(hm.key, hm.hash, level, hm.value.asInstanceOf[B1], hm.kv, merger)
case hm: HashTrieMap[_, _] =>
@@ -440,7 +477,7 @@ time { mNew.iterator.foreach( p => ()) }
}
new HashTrieMap[A, B1](this.bitmap | that.bitmap, merged, totalelems)
- case hm: HashMapCollision1[_, _] => that.merge0(this, level, merger)
+ case hm: HashMapCollision1[_, _] => that.merge0(this, level, if (merger ne null) merger.invert else null)
case hm: HashMap[_, _] => this
case _ => sys.error("section supposed to be unreachable.")
}
diff --git a/test/files/run/t5879.check b/test/files/run/t5879.check
new file mode 100644
index 0000000000..b6cbda35a7
--- /dev/null
+++ b/test/files/run/t5879.check
@@ -0,0 +1,16 @@
+Map(1 -> 1)
+1
+Map(1 -> 1)
+1
+(1,1)
+Map(1 -> 1)
+1
+(1,1)
+Map(1 -> 1)
+1
+(1,2)
+Map(1 -> 2)
+2
+(1,2)
+Map(1 -> 2)
+2 \ No newline at end of file
diff --git a/test/files/run/t5879.scala b/test/files/run/t5879.scala
new file mode 100644
index 0000000000..e1c07fc4c2
--- /dev/null
+++ b/test/files/run/t5879.scala
@@ -0,0 +1,74 @@
+import collection.immutable.HashMap
+
+
+object Test {
+
+ def main(args: Array[String]) {
+ resolveDefault()
+ resolveFirst()
+ resolveSecond()
+ resolveMany()
+ }
+
+ def resolveDefault() {
+ val a = HashMap(1 -> "1")
+ val b = HashMap(1 -> "2")
+
+ val r = a.merged(b)(null)
+ println(r)
+ println(r(1))
+
+ val rold = a.merge(b)
+ println(rold)
+ println(rold(1))
+ }
+
+ def resolveFirst() {
+ val a = HashMap(1 -> "1")
+ val b = HashMap(1 -> "2")
+ def collision(a: (Int, String), b: (Int, String)) = {
+ println(a)
+ a
+ }
+
+ val r = a.merged(b) { collision }
+ println(r)
+ println(r(1))
+
+ val rold = a.merge(b, collision)
+ println(rold)
+ println(rold(1))
+ }
+
+ def resolveSecond() {
+ val a = HashMap(1 -> "1")
+ val b = HashMap(1 -> "2")
+ def collision(a: (Int, String), b: (Int, String)) = {
+ println(b)
+ b
+ }
+
+ val r = a.merged(b) { collision }
+ println(r)
+ println(r(1))
+
+ val rold = a.merge(b, collision)
+ println(rold)
+ println(rold(1))
+ }
+
+ def resolveMany() {
+ val a = HashMap((0 until 100) zip (0 until 100): _*)
+ val b = HashMap((0 until 100) zip (100 until 200): _*)
+ def collision(a: (Int, Int), b: (Int, Int)) = {
+ (a._1, a._2 + b._2)
+ }
+
+ val r = a.merged(b) { collision }
+ for ((k, v) <- r) assert(v == 100 + 2 * k, (k, v))
+
+ val rold = a.merge(b, collision)
+ for ((k, v) <- r) assert(v == 100 + 2 * k, (k, v))
+ }
+
+}