summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/library/scala/collection/immutable/RedBlackTree.scala33
-rw-r--r--src/library/scala/collection/immutable/TreeMap.scala4
-rw-r--r--src/library/scala/collection/immutable/TreeSet.scala4
-rw-r--r--test/files/run/t5986.check15
-rw-r--r--test/files/run/t5986.scala36
-rw-r--r--test/files/scalacheck/redblacktree.scala4
6 files changed, 74 insertions, 22 deletions
diff --git a/src/library/scala/collection/immutable/RedBlackTree.scala b/src/library/scala/collection/immutable/RedBlackTree.scala
index 0f28c4997b..4b573511d1 100644
--- a/src/library/scala/collection/immutable/RedBlackTree.scala
+++ b/src/library/scala/collection/immutable/RedBlackTree.scala
@@ -43,7 +43,7 @@ object RedBlackTree {
}
def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count
- def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v))
+ def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1, overwrite: Boolean)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v, overwrite))
def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k))
def rangeImpl[A: Ordering, B](tree: Tree[A, B], from: Option[A], until: Option[A]): Tree[A, B] = (from, until) match {
case (Some(from), Some(until)) => this.range(tree, from, until)
@@ -122,17 +122,18 @@ object RedBlackTree {
else
mkTree(isBlack, x, xv, a, r)
}
- private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) {
+ private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1, overwrite: Boolean)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) {
RedTree(k, v, null, null)
} else {
val cmp = ordering.compare(k, tree.key)
- if (cmp < 0) balanceLeft(isBlackTree(tree), tree.key, tree.value, upd(tree.left, k, v), tree.right)
- else if (cmp > 0) balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, upd(tree.right, k, v))
- else mkTree(isBlackTree(tree), k, v, tree.left, tree.right)
+ if (cmp < 0) balanceLeft(isBlackTree(tree), tree.key, tree.value, upd(tree.left, k, v, overwrite), tree.right)
+ else if (cmp > 0) balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, upd(tree.right, k, v, overwrite))
+ else if (overwrite || k != tree.key) mkTree(isBlackTree(tree), k, v, tree.left, tree.right)
+ else tree
}
- // Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees
- // http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html
+ /* Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees
+ * http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html */
private[this] def del[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else {
def balance(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) {
if (isRedTree(tr)) {
@@ -216,7 +217,7 @@ object RedBlackTree {
if (ordering.lt(tree.key, from)) return doFrom(tree.right, from)
val newLeft = doFrom(tree.left, from)
if (newLeft eq tree.left) tree
- else if (newLeft eq null) upd(tree.right, tree.key, tree.value)
+ else if (newLeft eq null) upd(tree.right, tree.key, tree.value, false)
else rebalance(tree, newLeft, tree.right)
}
private[this] def doTo[A, B](tree: Tree[A, B], to: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
@@ -224,7 +225,7 @@ object RedBlackTree {
if (ordering.lt(to, tree.key)) return doTo(tree.left, to)
val newRight = doTo(tree.right, to)
if (newRight eq tree.right) tree
- else if (newRight eq null) upd(tree.left, tree.key, tree.value)
+ else if (newRight eq null) upd(tree.left, tree.key, tree.value, false)
else rebalance(tree, tree.left, newRight)
}
private[this] def doUntil[A, B](tree: Tree[A, B], until: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
@@ -232,7 +233,7 @@ object RedBlackTree {
if (ordering.lteq(until, tree.key)) return doUntil(tree.left, until)
val newRight = doUntil(tree.right, until)
if (newRight eq tree.right) tree
- else if (newRight eq null) upd(tree.left, tree.key, tree.value)
+ else if (newRight eq null) upd(tree.left, tree.key, tree.value, false)
else rebalance(tree, tree.left, newRight)
}
private[this] def doRange[A, B](tree: Tree[A, B], from: A, until: A)(implicit ordering: Ordering[A]): Tree[A, B] = {
@@ -242,8 +243,8 @@ object RedBlackTree {
val newLeft = doFrom(tree.left, from)
val newRight = doUntil(tree.right, until)
if ((newLeft eq tree.left) && (newRight eq tree.right)) tree
- else if (newLeft eq null) upd(newRight, tree.key, tree.value);
- else if (newRight eq null) upd(newLeft, tree.key, tree.value);
+ else if (newLeft eq null) upd(newRight, tree.key, tree.value, false);
+ else if (newRight eq null) upd(newLeft, tree.key, tree.value, false);
else rebalance(tree, newLeft, newRight)
}
@@ -254,7 +255,7 @@ object RedBlackTree {
if (n > count) return doDrop(tree.right, n - count - 1)
val newLeft = doDrop(tree.left, n)
if (newLeft eq tree.left) tree
- else if (newLeft eq null) upd(tree.right, tree.key, tree.value)
+ else if (newLeft eq null) upd(tree.right, tree.key, tree.value, false)
else rebalance(tree, newLeft, tree.right)
}
private[this] def doTake[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = {
@@ -264,7 +265,7 @@ object RedBlackTree {
if (n <= count) return doTake(tree.left, n)
val newRight = doTake(tree.right, n - count - 1)
if (newRight eq tree.right) tree
- else if (newRight eq null) upd(tree.left, tree.key, tree.value)
+ else if (newRight eq null) upd(tree.left, tree.key, tree.value, false)
else rebalance(tree, tree.left, newRight)
}
private[this] def doSlice[A: Ordering, B](tree: Tree[A, B], from: Int, until: Int): Tree[A, B] = {
@@ -275,8 +276,8 @@ object RedBlackTree {
val newLeft = doDrop(tree.left, from)
val newRight = doTake(tree.right, until - count - 1)
if ((newLeft eq tree.left) && (newRight eq tree.right)) tree
- else if (newLeft eq null) upd(newRight, tree.key, tree.value)
- else if (newRight eq null) upd(newLeft, tree.key, tree.value)
+ else if (newLeft eq null) upd(newRight, tree.key, tree.value, false)
+ else if (newRight eq null) upd(newLeft, tree.key, tree.value, false)
else rebalance(tree, newLeft, newRight)
}
diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala
index 4c1a5f2e03..51bc76efc3 100644
--- a/src/library/scala/collection/immutable/TreeMap.scala
+++ b/src/library/scala/collection/immutable/TreeMap.scala
@@ -131,7 +131,7 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi
* @param value the value to be associated with `key`
* @return a new $coll with the updated binding
*/
- override def updated [B1 >: B](key: A, value: B1): TreeMap[A, B1] = new TreeMap(RB.update(tree, key, value))
+ override def updated [B1 >: B](key: A, value: B1): TreeMap[A, B1] = new TreeMap(RB.update(tree, key, value, true))
/** Add a key/value pair to this map.
* @tparam B1 type of the value of the new binding, a supertype of `B`
@@ -171,7 +171,7 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi
*/
def insert [B1 >: B](key: A, value: B1): TreeMap[A, B1] = {
assert(!RB.contains(tree, key))
- new TreeMap(RB.update(tree, key, value))
+ new TreeMap(RB.update(tree, key, value, true))
}
def - (key:A): TreeMap[A, B] =
diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala
index 882e828c5b..697da2bc4b 100644
--- a/src/library/scala/collection/immutable/TreeSet.scala
+++ b/src/library/scala/collection/immutable/TreeSet.scala
@@ -112,7 +112,7 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin
* @param elem a new element to add.
* @return a new $coll containing `elem` and all the elements of this $coll.
*/
- def + (elem: A): TreeSet[A] = newSet(RB.update(tree, elem, ()))
+ def + (elem: A): TreeSet[A] = newSet(RB.update(tree, elem, (), false))
/** A new `TreeSet` with the entry added is returned,
* assuming that elem is <em>not</em> in the TreeSet.
@@ -122,7 +122,7 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin
*/
def insert(elem: A): TreeSet[A] = {
assert(!RB.contains(tree, elem))
- newSet(RB.update(tree, elem, ()))
+ newSet(RB.update(tree, elem, (), false))
}
/** Creates a new `TreeSet` with the entry removed.
diff --git a/test/files/run/t5986.check b/test/files/run/t5986.check
new file mode 100644
index 0000000000..4101770c6d
--- /dev/null
+++ b/test/files/run/t5986.check
@@ -0,0 +1,15 @@
+Foo(bar, 1)
+Foo(bar, 1)
+Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
+Foo(bar, 1)
+Foo(bar, 1)
+Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
+Foo(bar, 1)
+Foo(bar, 1)
+Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
+Foo(bar, 1)
+Foo(bar, 1)
+Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4)
+Foo(bar, 1)
+Foo(bar, 1)
+Foo(bar, 1),Foo(baz, 3),Foo(bazz, 4) \ No newline at end of file
diff --git a/test/files/run/t5986.scala b/test/files/run/t5986.scala
new file mode 100644
index 0000000000..8cf7086f98
--- /dev/null
+++ b/test/files/run/t5986.scala
@@ -0,0 +1,36 @@
+
+
+
+import scala.collection._
+
+
+
+/** A sorted set should not replace elements when adding
+ * and the element already exists in the set.
+ */
+object Test {
+
+ class Foo(val name: String, val n: Int) {
+ override def equals(obj: Any): Boolean = obj match { case other: Foo => name == other.name; case _ => false }
+ override def hashCode = name.##
+ override def toString = "Foo(" + name + ", " + n + ")"
+ }
+
+ implicit val ordering: Ordering[Foo] = Ordering.fromLessThan[Foo] { (a, b) => a.name.compareTo(b.name) < 0 }
+
+ def check[S <: Set[Foo]](set: S) {
+ def output(s: Set[Foo]) = println(s.toList.sorted.mkString(","))
+ output(set + new Foo("bar", 2))
+ output(set ++ List(new Foo("bar", 2), new Foo("bar", 3), new Foo("bar", 4)))
+ output(set union Set(new Foo("bar", 2), new Foo("baz", 3), new Foo("bazz", 4)))
+ }
+
+ def main(args: Array[String]) {
+ check(Set(new Foo("bar", 1)))
+ check(immutable.Set(new Foo("bar", 1)))
+ check(mutable.Set(new Foo("bar", 1)))
+ check(immutable.SortedSet(new Foo("bar", 1)))
+ check(mutable.SortedSet(new Foo("bar", 1)))
+ }
+
+}
diff --git a/test/files/scalacheck/redblacktree.scala b/test/files/scalacheck/redblacktree.scala
index e4b356c889..e2609fa200 100644
--- a/test/files/scalacheck/redblacktree.scala
+++ b/test/files/scalacheck/redblacktree.scala
@@ -121,7 +121,7 @@ package scala.collection.immutable.redblacktree {
override type ModifyParm = Int
override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1)
- override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0)
+ override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0, true)
def generateKey(tree: Tree[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match {
case Some((key, _)) => key.init.mkString + "MN"
@@ -144,7 +144,7 @@ package scala.collection.immutable.redblacktree {
override type ModifyParm = Int
override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size)
override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map {
- case (key, _) => update(tree, key, newValue)
+ case (key, _) => update(tree, key, newValue, true)
} getOrElse tree
property("update modifies values") = forAll(genInput) { case (tree, parm, newTree) =>