From 6b950741c58938eab922908ac4fb809b7ca68c01 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Wed, 21 Dec 2011 09:28:42 +0100 Subject: Make sure the redblack test compiles and runs. --- test/files/scalacheck/redblack.scala | 76 +++++++++++++++++------------------- 1 file changed, 36 insertions(+), 40 deletions(-) (limited to 'test/files/scalacheck') diff --git a/test/files/scalacheck/redblack.scala b/test/files/scalacheck/redblack.scala index 1fcaa46f0e..011a5d0ca5 100644 --- a/test/files/scalacheck/redblack.scala +++ b/test/files/scalacheck/redblack.scala @@ -18,22 +18,18 @@ abstract class RedBlackTest extends Properties("RedBlack") { def minimumSize = 0 def maximumSize = 5 - object RedBlackTest extends scala.collection.immutable.RedBlack[String] { - def isSmaller(x: String, y: String) = x < y - } - - import RedBlackTest._ + import collection.immutable.RedBlack._ - def nodeAt[A](tree: Tree[A], n: Int): Option[(String, A)] = if (n < tree.iterator.size && n >= 0) + def nodeAt[A](tree: Tree[String, A], n: Int): Option[(String, A)] = if (n < tree.iterator.size && n >= 0) Some(tree.iterator.drop(n).next) else None - def treeContains[A](tree: Tree[A], key: String) = tree.iterator.map(_._1) contains key + def treeContains[A](tree: Tree[String, A], key: String) = tree.iterator.map(_._1) contains key - def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[Int]] = + def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[String, Int]] = if (level == 0) { - value(Empty) + value(Empty.empty) } else { for { oddOrEven <- choose(0, 2) @@ -56,10 +52,10 @@ abstract class RedBlackTest extends Properties("RedBlack") { } yield tree type ModifyParm - def genParm(tree: Tree[Int]): Gen[ModifyParm] - def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] + def genParm(tree: Tree[String, Int]): Gen[ModifyParm] + def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] - def genInput: Gen[(Tree[Int], ModifyParm, Tree[Int])] = for { + def genInput: Gen[(Tree[String, Int], ModifyParm, Tree[String, Int])] = for { tree <- genTree parm <- genParm(tree) } yield (tree, parm, modify(tree, parm)) @@ -68,30 +64,30 @@ abstract class RedBlackTest extends Properties("RedBlack") { trait RedBlackInvariants { self: RedBlackTest => - import RedBlackTest._ + import collection.immutable.RedBlack._ - def rootIsBlack[A](t: Tree[A]) = t.isBlack + def rootIsBlack[A](t: Tree[String, A]) = t.isBlack - def areAllLeavesBlack[A](t: Tree[A]): Boolean = t match { - case Empty => t.isBlack - case ne: NonEmpty[_] => List(ne.left, ne.right) forall areAllLeavesBlack + def areAllLeavesBlack[A](t: Tree[String, A]): Boolean = t match { + case Empty.Instance => t.isBlack + case ne: NonEmpty[_, _] => List(ne.left, ne.right) forall areAllLeavesBlack } - def areRedNodeChildrenBlack[A](t: Tree[A]): Boolean = t match { + def areRedNodeChildrenBlack[A](t: Tree[String, A]): Boolean = t match { case RedTree(_, _, left, right) => List(left, right) forall (t => t.isBlack && areRedNodeChildrenBlack(t)) case BlackTree(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack - case Empty => true + case Empty.Instance => true } - def blackNodesToLeaves[A](t: Tree[A]): List[Int] = t match { - case Empty => List(1) + def blackNodesToLeaves[A](t: Tree[String, A]): List[Int] = t match { + case Empty.Instance => List(1) case BlackTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) case RedTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves } - def areBlackNodesToLeavesEqual[A](t: Tree[A]): Boolean = t match { - case Empty => true - case ne: NonEmpty[_] => + def areBlackNodesToLeavesEqual[A](t: Tree[String, A]): Boolean = t match { + case Empty.Instance => true + case ne: NonEmpty[_, _] => ( blackNodesToLeaves(ne).distinct.size == 1 && areBlackNodesToLeavesEqual(ne.left) @@ -99,10 +95,10 @@ trait RedBlackInvariants { ) } - def orderIsPreserved[A](t: Tree[A]): Boolean = - t.iterator zip t.iterator.drop(1) forall { case (x, y) => isSmaller(x._1, y._1) } + def orderIsPreserved[A](t: Tree[String, A]): Boolean = + t.iterator zip t.iterator.drop(1) forall { case (x, y) => x._1 < y._1 } - def setup(invariant: Tree[Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => + def setup(invariant: Tree[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => invariant(newTree) } @@ -114,13 +110,13 @@ trait RedBlackInvariants { } object TestInsert extends RedBlackTest with RedBlackInvariants { - import RedBlackTest._ + import collection.immutable.RedBlack._ override type ModifyParm = Int - override def genParm(tree: Tree[Int]): Gen[ModifyParm] = choose(0, tree.iterator.size + 1) - override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = tree update (generateKey(tree, parm), 0) + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size + 1) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = tree update (generateKey(tree, parm), 0) - def generateKey(tree: Tree[Int], parm: ModifyParm): String = nodeAt(tree, parm) match { + def generateKey(tree: Tree[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match { case Some((key, _)) => key.init.mkString + "MN" case None => nodeAt(tree, parm - 1) match { case Some((key, _)) => key.init.mkString + "RN" @@ -134,13 +130,13 @@ object TestInsert extends RedBlackTest with RedBlackInvariants { } object TestModify extends RedBlackTest { - import RedBlackTest._ + import collection.immutable.RedBlack._ def newValue = 1 override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Tree[Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) - override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = nodeAt(tree, parm) map { + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { case (key, _) => tree update (key, newValue) } getOrElse tree @@ -152,12 +148,12 @@ object TestModify extends RedBlackTest { } object TestDelete extends RedBlackTest with RedBlackInvariants { - import RedBlackTest._ + import collection.immutable.RedBlack._ override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Tree[Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) - override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = nodeAt(tree, parm) map { + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { case (key, _) => tree delete key } getOrElse tree @@ -169,17 +165,17 @@ object TestDelete extends RedBlackTest with RedBlackInvariants { } object TestRange extends RedBlackTest with RedBlackInvariants { - import RedBlackTest._ + import collection.immutable.RedBlack._ override type ModifyParm = (Option[Int], Option[Int]) - override def genParm(tree: Tree[Int]): Gen[ModifyParm] = for { + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { from <- choose(0, tree.iterator.size) to <- choose(0, tree.iterator.size) suchThat (from <=) optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug } yield (optionalFrom, optionalTo) - override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = { + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = { val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) tree range (from, to) -- cgit v1.2.3 From b9699f999da24f72dca65ecfb066b0ac3151f2b5 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Tue, 27 Dec 2011 10:23:04 +0100 Subject: Made RedBlack private to the scala.collection.immutable package. Use ArrayStack instead of Stack in TreeIterator for slightly increased performance. --- src/library/scala/collection/immutable/RedBlack.scala | 7 +++---- test/files/scalacheck/redblack.scala | 15 +++++++++------ 2 files changed, 12 insertions(+), 10 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 4b81182657..19e0e5ae55 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -15,7 +15,8 @@ package immutable * * @since 2.3 */ -object RedBlack extends Serializable { +private[immutable] +object RedBlack { private def blacken[A, B](t: Tree[A, B]): Tree[A, B] = t match { case RedTree(k, v, l, r) => BlackTree(k, v, l, r) @@ -302,8 +303,6 @@ object RedBlack extends Serializable { } private[this] class TreeIterator[A, B](tree: NonEmpty[A, B]) extends Iterator[(A, B)] { - import collection.mutable.Stack - override def hasNext: Boolean = !next.isEmpty override def next: (A, B) = next match { @@ -326,7 +325,7 @@ object RedBlack extends Serializable { } } - private[this] val path: Stack[NonEmpty[A, B]] = Stack.empty[NonEmpty[A, B]] + private[this] val path = mutable.ArrayStack.empty[NonEmpty[A, B]] addLeftMostBranchToPath(tree) private[this] var next: Tree[A, B] = path.pop() } diff --git a/test/files/scalacheck/redblack.scala b/test/files/scalacheck/redblack.scala index 011a5d0ca5..78fb645ce8 100644 --- a/test/files/scalacheck/redblack.scala +++ b/test/files/scalacheck/redblack.scala @@ -1,3 +1,4 @@ +import collection.immutable._ import org.scalacheck._ import Prop._ import Gen._ @@ -14,11 +15,12 @@ Both children of every red node are black. Every simple path from a given node to any of its descendant leaves contains the same number of black nodes. */ +package scala.collection.immutable { abstract class RedBlackTest extends Properties("RedBlack") { def minimumSize = 0 def maximumSize = 5 - import collection.immutable.RedBlack._ + import RedBlack._ def nodeAt[A](tree: Tree[String, A], n: Int): Option[(String, A)] = if (n < tree.iterator.size && n >= 0) Some(tree.iterator.drop(n).next) @@ -64,7 +66,7 @@ abstract class RedBlackTest extends Properties("RedBlack") { trait RedBlackInvariants { self: RedBlackTest => - import collection.immutable.RedBlack._ + import RedBlack._ def rootIsBlack[A](t: Tree[String, A]) = t.isBlack @@ -110,7 +112,7 @@ trait RedBlackInvariants { } object TestInsert extends RedBlackTest with RedBlackInvariants { - import collection.immutable.RedBlack._ + import RedBlack._ override type ModifyParm = Int override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size + 1) @@ -130,7 +132,7 @@ object TestInsert extends RedBlackTest with RedBlackInvariants { } object TestModify extends RedBlackTest { - import collection.immutable.RedBlack._ + import RedBlack._ def newValue = 1 override def minimumSize = 1 @@ -148,7 +150,7 @@ object TestModify extends RedBlackTest { } object TestDelete extends RedBlackTest with RedBlackInvariants { - import collection.immutable.RedBlack._ + import RedBlack._ override def minimumSize = 1 override type ModifyParm = Int @@ -165,7 +167,7 @@ object TestDelete extends RedBlackTest with RedBlackInvariants { } object TestRange extends RedBlackTest with RedBlackInvariants { - import collection.immutable.RedBlack._ + import RedBlack._ override type ModifyParm = (Option[Int], Option[Int]) override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { @@ -199,6 +201,7 @@ object TestRange extends RedBlackTest with RedBlackInvariants { filteredTree == newTree.iterator.map(_._1).toList } } +} object Test extends Properties("RedBlack") { include(TestInsert) -- cgit v1.2.3 From ad0b09c0c9606d43df7e3a76c535b3943e8d583a Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Wed, 28 Dec 2011 10:21:56 +0100 Subject: Added some tests for TreeMap/TreeSet. --- test/files/scalacheck/treemap.scala | 93 +++++++++++++++++++++++++++++++++++++ test/files/scalacheck/treeset.scala | 89 +++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 test/files/scalacheck/treemap.scala create mode 100644 test/files/scalacheck/treeset.scala (limited to 'test/files/scalacheck') diff --git a/test/files/scalacheck/treemap.scala b/test/files/scalacheck/treemap.scala new file mode 100644 index 0000000000..43d307600d --- /dev/null +++ b/test/files/scalacheck/treemap.scala @@ -0,0 +1,93 @@ +import collection.immutable._ +import org.scalacheck._ +import Prop._ +import Gen._ +import Arbitrary._ +import util._ +import Buildable._ + +object Test extends Properties("TreeMap") { + implicit def arbTreeMap[A : Arbitrary : Ordering, B : Arbitrary]: Arbitrary[TreeMap[A, B]] = + Arbitrary(for { + keys <- listOf(arbitrary[A]) + values <- listOfN(keys.size, arbitrary[B]) + } yield TreeMap(keys zip values: _*)) + + property("foreach/iterator consistency") = forAll { (subject: TreeMap[Int, String]) => + val it = subject.iterator + var consistent = true + subject.foreach { element => + consistent &&= it.hasNext && element == it.next + } + consistent + } + + property("sorted") = forAll { (subject: TreeMap[Int, String]) => (subject.size >= 3) ==> { + subject.zip(subject.tail).forall { case (x, y) => x._1 < y._1 } + }} + + property("contains all") = forAll { (arr: List[(Int, String)]) => + val subject = TreeMap(arr: _*) + arr.map(_._1).forall(subject.contains(_)) + } + + property("size") = forAll { (elements: List[(Int, Int)]) => + val subject = TreeMap(elements: _*) + elements.map(_._1).distinct.size == subject.size + } + + property("toSeq") = forAll { (elements: List[(Int, Int)]) => + val subject = TreeMap(elements: _*) + elements.map(_._1).distinct.sorted == subject.toSeq.map(_._1) + } + + property("head") = forAll { (elements: List[Int]) => elements.nonEmpty ==> { + val subject = TreeMap(elements zip elements: _*) + elements.min == subject.head._1 + }} + + property("last") = forAll { (elements: List[Int]) => elements.nonEmpty ==> { + val subject = TreeMap(elements zip elements: _*) + elements.max == subject.last._1 + }} + + property("head/tail identity") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { + subject == (subject.tail + subject.head) + }} + + property("init/last identity") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { + subject == (subject.init + subject.last) + }} + + property("take") = forAll { (subject: TreeMap[Int, String]) => + val n = choose(0, subject.size).sample.get + n == subject.take(n).size && subject.take(n).forall(elt => subject.get(elt._1) == Some(elt._2)) + } + + property("drop") = forAll { (subject: TreeMap[Int, String]) => + val n = choose(0, subject.size).sample.get + (subject.size - n) == subject.drop(n).size && subject.drop(n).forall(elt => subject.get(elt._1) == Some(elt._2)) + } + + property("take/drop identity") = forAll { (subject: TreeMap[Int, String]) => + val n = choose(-1, subject.size + 1).sample.get + subject == subject.take(n) ++ subject.drop(n) + } + + property("splitAt") = forAll { (subject: TreeMap[Int, String]) => + val n = choose(-1, subject.size + 1).sample.get + val (prefix, suffix) = subject.splitAt(n) + prefix == subject.take(n) && suffix == subject.drop(n) + } + + property("remove single") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { + val key = oneOf(subject.keys.toSeq).sample.get + val removed = subject - key + subject.contains(key) && !removed.contains(key) && subject.size - 1 == removed.size + }} + + property("remove all") = forAll { (subject: TreeMap[Int, String]) => + val result = subject.foldLeft(subject)((acc, elt) => acc - elt._1) + result.isEmpty + } +} diff --git a/test/files/scalacheck/treeset.scala b/test/files/scalacheck/treeset.scala new file mode 100644 index 0000000000..3cefef7040 --- /dev/null +++ b/test/files/scalacheck/treeset.scala @@ -0,0 +1,89 @@ +import collection.immutable._ +import org.scalacheck._ +import Prop._ +import Gen._ +import Arbitrary._ +import util._ + +object Test extends Properties("TreeSet") { + implicit def arbTreeSet[A : Arbitrary : Ordering]: Arbitrary[TreeSet[A]] = + Arbitrary(listOf(arbitrary[A]) map (elements => TreeSet(elements: _*))) + + property("foreach/iterator consistency") = forAll { (subject: TreeSet[Int]) => + val it = subject.iterator + var consistent = true + subject.foreach { element => + consistent &&= it.hasNext && element == it.next + } + consistent + } + + property("sorted") = forAll { (subject: TreeSet[Int]) => (subject.size >= 3) ==> { + subject.zip(subject.tail).forall { case (x, y) => x < y } + }} + + property("contains all") = forAll { (elements: List[Int]) => + val subject = TreeSet(elements: _*) + elements.forall(subject.contains) + } + + property("size") = forAll { (elements: List[Int]) => + val subject = TreeSet(elements: _*) + elements.distinct.size == subject.size + } + + property("toSeq") = forAll { (elements: List[Int]) => + val subject = TreeSet(elements: _*) + elements.distinct.sorted == subject.toSeq + } + + property("head") = forAll { (elements: List[Int]) => elements.nonEmpty ==> { + val subject = TreeSet(elements: _*) + elements.min == subject.head + }} + + property("last") = forAll { (elements: List[Int]) => elements.nonEmpty ==> { + val subject = TreeSet(elements: _*) + elements.max == subject.last + }} + + property("head/tail identity") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { + subject == (subject.tail + subject.head) + }} + + property("init/last identity") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { + subject == (subject.init + subject.last) + }} + + property("take") = forAll { (subject: TreeSet[Int]) => + val n = choose(0, subject.size).sample.get + n == subject.take(n).size && subject.take(n).forall(subject.contains) + } + + property("drop") = forAll { (subject: TreeSet[Int]) => + val n = choose(0, subject.size).sample.get + (subject.size - n) == subject.drop(n).size && subject.drop(n).forall(subject.contains) + } + + property("take/drop identity") = forAll { (subject: TreeSet[Int]) => + val n = choose(-1, subject.size + 1).sample.get + subject == subject.take(n) ++ subject.drop(n) + } + + property("splitAt") = forAll { (subject: TreeSet[Int]) => + val n = choose(-1, subject.size + 1).sample.get + val (prefix, suffix) = subject.splitAt(n) + prefix == subject.take(n) && suffix == subject.drop(n) + } + + property("remove single") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { + val element = oneOf(subject.toSeq).sample.get + val removed = subject - element + subject.contains(element) && !removed.contains(element) && subject.size - 1 == removed.size + }} + + property("remove all") = forAll { (subject: TreeSet[Int]) => + val result = subject.foldLeft(subject)((acc, elt) => acc - elt) + result.isEmpty + } +} -- cgit v1.2.3 From 5c05f66b619ea9c8a543518fd9d7d601268c6f19 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Mon, 2 Jan 2012 19:48:37 +0100 Subject: Use null to represent empty trees. Removed Empty/NonEmpty classes. --- .../scala/collection/immutable/RedBlack.scala | 569 ++++++++++----------- .../scala/collection/immutable/TreeMap.scala | 46 +- .../scala/collection/immutable/TreeSet.scala | 44 +- test/files/scalacheck/redblack.scala | 112 ++-- 4 files changed, 367 insertions(+), 404 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 3b16f719bf..2537d043fd 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -11,6 +11,8 @@ package scala.collection package immutable +import annotation.meta.getter + /** An object containing the RedBlack tree implementation used by for `TreeMaps` and `TreeSets`. * * @since 2.3 @@ -18,389 +20,354 @@ package immutable private[immutable] object RedBlack { - private def blacken[A, B](t: Tree[A, B]): Tree[A, B] = t.black + private def blacken[A, B](t: Tree[A, B]): Tree[A, B] = if (t eq null) null else t.black private def mkTree[A, B](isBlack: Boolean, k: A, v: B, l: Tree[A, B], r: Tree[A, B]) = if (isBlack) BlackTree(k, v, l, r) else RedTree(k, v, l, r) - def isRedTree[A, B](tree: Tree[A, B]) = tree.isInstanceOf[RedTree[_, _]] + + def isBlack(tree: Tree[_, _]) = (tree eq null) || isBlackTree(tree) + def isRedTree(tree: Tree[_, _]) = tree.isInstanceOf[RedTree[_, _]] def isBlackTree(tree: Tree[_, _]) = tree.isInstanceOf[BlackTree[_, _]] + def isEmpty(tree: Tree[_, _]): Boolean = tree eq null + + def contains[A](tree: Tree[A, _], x: A)(implicit ordering: Ordering[A]): Boolean = lookup(tree, x) ne null + def get[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Option[B] = lookup(tree, x) match { + case null => None + case tree => Some(tree.value) + } + @annotation.tailrec - def lookup[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq Empty.Instance) tree else { - val cmp = ordering.compare(x, tree.key) - if (cmp < 0) lookup(tree.left, x) - else if (cmp > 0) lookup(tree.right, x) - else tree + def lookup[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else { + val cmp = ordering.compare(x, tree.key) + if (cmp < 0) lookup(tree.left, x) + else if (cmp > 0) lookup(tree.right, x) + else tree } - sealed abstract class Tree[A, +B] extends Serializable { - def key: A - def value: B - def left: Tree[A, B] - def right: Tree[A, B] - def isEmpty: Boolean - def isBlack: Boolean - def lookup(x: A)(implicit ordering: Ordering[A]): Tree[A, B] - def update[B1 >: B](k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(k, v)) - def delete(k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(k)) - def range(from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(rng(from, until)) - def foreach[U](f: ((A, B)) => U) - def foreachKey[U](f: A => U) - def iterator: Iterator[(A, B)] - def keyIterator: Iterator[A] - def upd[B1 >: B](k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] - def del(k: A)(implicit ordering: Ordering[A]): Tree[A, B] - def smallest: NonEmpty[A, B] - def greatest: NonEmpty[A, B] - def rng(from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] - def first : A - def last : A - def count : Int - protected[immutable] def nth(n: Int): NonEmpty[A, B] - def black: Tree[A, B] = this - def red: Tree[A, B] + + def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count + def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v)) + def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k)) + def range[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(rng(tree, from, until)) + + def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = { + if (tree eq null) throw new NoSuchElementException("empty map") + var result = tree + while (result.left ne null) result = result.left + result } - sealed abstract class NonEmpty[A, +B](final val key: A, final val value: B, final val left: Tree[A, B], final val right: Tree[A, B]) extends Tree[A, B] with Serializable { - def isEmpty = false - def lookup(k: A)(implicit ordering: Ordering[A]): Tree[A, B] = { - val cmp = ordering.compare(k, key) - if (cmp < 0) left.lookup(k) - else if (cmp > 0) right.lookup(k) - else this - } - private[this] def balanceLeft[B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[A, B1], d: Tree[A, B1])/*: NonEmpty[A, B1]*/ = { - if (isRedTree(l) && isRedTree(l.left)) - RedTree(l.key, l.value, BlackTree(l.left.key, l.left.value, l.left.left, l.left.right), BlackTree(z, zv, l.right, d)) - else if (isRedTree(l) && isRedTree(l.right)) - RedTree(l.right.key, l.right.value, BlackTree(l.key, l.value, l.left, l.right.left), BlackTree(z, zv, l.right.right, d)) - else - mkTree(isBlack, z, zv, l, d) - } - private[this] def balanceRight[B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[A, B1], r: Tree[A, B1])/*: NonEmpty[A, B1]*/ = { - if (isRedTree(r) && isRedTree(r.left)) - RedTree(r.left.key, r.left.value, BlackTree(x, xv, a, r.left.left), BlackTree(r.key, r.value, r.left.right, r.right)) - else if (isRedTree(r) && isRedTree(r.right)) - RedTree(r.key, r.value, BlackTree(x, xv, a, r.left), BlackTree(r.right.key, r.right.value, r.right.left, r.right.right)) - else - mkTree(isBlack, x, xv, a, r) - } - def upd[B1 >: B](k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = { - val cmp = ordering.compare(k, key) - if (cmp < 0) balanceLeft(isBlack, key, value, left.upd(k, v), right) - else if (cmp > 0) balanceRight(isBlack, key, value, left, right.upd(k, v)) - else mkTree(isBlack, k, v, left, right) - } + def greatest[A, B](tree: Tree[A, B]): Tree[A, B] = { + if (tree eq null) throw new NoSuchElementException("empty map") + var result = tree + while (result.right ne null) result = result.right + result + } + + def foreach[A, B, U](tree: Tree[A, B], f: ((A, B)) => U): Unit = if (tree ne null) { + foreach(tree.left, f) + f((tree.key, tree.value)) + foreach(tree.right, f) + } + def foreachKey[A, U](tree: Tree[A, _], f: A => U): Unit = if (tree ne null) { + foreachKey(tree.left, f) + f(tree.key) + foreachKey(tree.right, f) + } + + def iterator[A, B](tree: Tree[A, B]): Iterator[(A, B)] = if (tree eq null) Iterator.empty else new TreeIterator(tree) + def keyIterator[A, _](tree: Tree[A, _]): Iterator[A] = if (tree eq null) Iterator.empty else new TreeKeyIterator(tree) + + private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[A, B1], d: Tree[A, B1]): Tree[A, B1] = { + if (isRedTree(l) && isRedTree(l.left)) + RedTree(l.key, l.value, BlackTree(l.left.key, l.left.value, l.left.left, l.left.right), BlackTree(z, zv, l.right, d)) + else if (isRedTree(l) && isRedTree(l.right)) + RedTree(l.right.key, l.right.value, BlackTree(l.key, l.value, l.left, l.right.left), BlackTree(z, zv, l.right.right, d)) + else + mkTree(isBlack, z, zv, l, d) + } + private[this] def balanceRight[A, B, B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[A, B1], r: Tree[A, B1]): Tree[A, B1] = { + if (isRedTree(r) && isRedTree(r.left)) + RedTree(r.left.key, r.left.value, BlackTree(x, xv, a, r.left.left), BlackTree(r.key, r.value, r.left.right, r.right)) + else if (isRedTree(r) && isRedTree(r.right)) + RedTree(r.key, r.value, BlackTree(x, xv, a, r.left), BlackTree(r.right.key, r.right.value, r.right.left, r.right.right)) + else + mkTree(isBlack, x, xv, a, r) + } + private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree == null) { + RedTree(k, v, null, null) + } else { + val cmp = ordering.compare(k, tree.key) + if (cmp < 0) balanceLeft(tree.isBlack, tree.key, tree.value, upd(tree.left, k, v), tree.right) + else if (cmp > 0) balanceRight(tree.isBlack, tree.key, tree.value, tree.left, upd(tree.right, k, v)) + else mkTree(tree.isBlack, k, v, tree.left, tree.right) + } + // Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees // http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html - def del(k: A)(implicit ordering: Ordering[A]): Tree[A, B] = { - def balance(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { - if (isRedTree(tr)) { - RedTree(x, xv, tl.black, tr.black) - } else if (isRedTree(tl.left)) { - RedTree(tl.key, tl.value, tl.left.black, BlackTree(x, xv, tl.right, tr)) - } else if (isRedTree(tl.right)) { - RedTree(tl.right.key, tl.right.value, BlackTree(tl.key, tl.value, tl.left, tl.right.left), BlackTree(x, xv, tl.right.right, tr)) - } else { - BlackTree(x, xv, tl, tr) - } - } else if (isRedTree(tr)) { - if (isRedTree(tr.right)) { - RedTree(tr.key, tr.value, BlackTree(x, xv, tl, tr.left), tr.right.black) - } else if (isRedTree(tr.left)) { - RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), BlackTree(tr.key, tr.value, tr.left.right, tr.right)) - } else { - BlackTree(x, xv, tl, tr) - } + private[this] def del[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree == null) null else { + def balance(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { + if (isRedTree(tr)) { + RedTree(x, xv, tl.black, tr.black) + } else if (isRedTree(tl.left)) { + RedTree(tl.key, tl.value, tl.left.black, BlackTree(x, xv, tl.right, tr)) + } else if (isRedTree(tl.right)) { + RedTree(tl.right.key, tl.right.value, BlackTree(tl.key, tl.value, tl.left, tl.right.left), BlackTree(x, xv, tl.right.right, tr)) } else { BlackTree(x, xv, tl, tr) } - def subl(t: Tree[A, B]) = - if (t.isInstanceOf[BlackTree[_, _]]) t.red - else sys.error("Defect: invariance violation; expected black, got "+t) - - def balLeft(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { - RedTree(x, xv, tl.black, tr) - } else if (isBlackTree(tr)) { - balance(x, xv, tl, tr.red) - } else if (isRedTree(tr) && isBlackTree(tr.left)) { - RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, subl(tr.right))) + } else if (isRedTree(tr)) { + if (isRedTree(tr.right)) { + RedTree(tr.key, tr.value, BlackTree(x, xv, tl, tr.left), tr.right.black) + } else if (isRedTree(tr.left)) { + RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), BlackTree(tr.key, tr.value, tr.left.right, tr.right)) } else { - sys.error("Defect: invariance violation at "+right) + BlackTree(x, xv, tl, tr) } - def balRight(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tr)) { - RedTree(x, xv, tl, tr.black) - } else if (isBlackTree(tl)) { - balance(x, xv, tl.red, tr) - } else if (isRedTree(tl) && isBlackTree(tl.right)) { - RedTree(tl.right.key, tl.right.value, balance(tl.key, tl.value, subl(tl.left), tl.right.left), BlackTree(x, xv, tl.right.right, tr)) + } else { + BlackTree(x, xv, tl, tr) + } + def subl(t: Tree[A, B]) = + if (t.isInstanceOf[BlackTree[_, _]]) t.red + else sys.error("Defect: invariance violation; expected black, got "+t) + + def balLeft(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { + RedTree(x, xv, tl.black, tr) + } else if (isBlackTree(tr)) { + balance(x, xv, tl, tr.red) + } else if (isRedTree(tr) && isBlackTree(tr.left)) { + RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, subl(tr.right))) + } else { + sys.error("Defect: invariance violation at ") // TODO + } + def balRight(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tr)) { + RedTree(x, xv, tl, tr.black) + } else if (isBlackTree(tl)) { + balance(x, xv, tl.red, tr) + } else if (isRedTree(tl) && isBlackTree(tl.right)) { + RedTree(tl.right.key, tl.right.value, balance(tl.key, tl.value, subl(tl.left), tl.right.left), BlackTree(x, xv, tl.right.right, tr)) + } else { + sys.error("Defect: invariance violation at ") // TODO + } + def delLeft = if (isBlackTree(tree.left)) balLeft(tree.key, tree.value, del(tree.left, k), tree.right) else RedTree(tree.key, tree.value, del(tree.left, k), tree.right) + def delRight = if (isBlackTree(tree.right)) balRight(tree.key, tree.value, tree.left, del(tree.right, k)) else RedTree(tree.key, tree.value, tree.left, del(tree.right, k)) + def append(tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] = if (tl eq null) { + tr + } else if (tr eq null) { + tl + } else if (isRedTree(tl) && isRedTree(tr)) { + val bc = append(tl.right, tr.left) + if (isRedTree(bc)) { + RedTree(bc.key, bc.value, RedTree(tl.key, tl.value, tl.left, bc.left), RedTree(tr.key, tr.value, bc.right, tr.right)) } else { - sys.error("Defect: invariance violation at "+left) + RedTree(tl.key, tl.value, tl.left, RedTree(tr.key, tr.value, bc, tr.right)) } - def delLeft = if (isBlackTree(left)) balLeft(key, value, left.del(k), right) else RedTree(key, value, left.del(k), right) - def delRight = if (isBlackTree(right)) balRight(key, value, left, right.del(k)) else RedTree(key, value, left, right.del(k)) - def append(tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] = if (tl eq Empty.Instance) { - tr - } else if (tr eq Empty.Instance) { - tl - } else if (isRedTree(tl) && isRedTree(tr)) { - val bc = append(tl.right, tr.left) - if (isRedTree(bc)) { - RedTree(bc.key, bc.value, RedTree(tl.key, tl.value, tl.left, bc.left), RedTree(tr.key, tr.value, bc.right, tr.right)) - } else { - RedTree(tl.key, tl.value, tl.left, RedTree(tr.key, tr.value, bc, tr.right)) - } - } else if (isBlackTree(tl) && isBlackTree(tr)) { - val bc = append(tl.right, tr.left) - if (isRedTree(bc)) { - RedTree(bc.key, bc.value, BlackTree(tl.key, tl.value, tl.left, bc.left), BlackTree(tr.key, tr.value, bc.right, tr.right)) - } else { - balLeft(tl.key, tl.value, tl.left, BlackTree(tr.key, tr.value, bc, tr.right)) - } - } else if (isRedTree(tr)) { - RedTree(tr.key, tr.value, append(tl, tr.left), tr.right) - } else if (isRedTree(tl)) { - RedTree(tl.key, tl.value, tl.left, append(tl.right, tr)) + } else if (isBlackTree(tl) && isBlackTree(tr)) { + val bc = append(tl.right, tr.left) + if (isRedTree(bc)) { + RedTree(bc.key, bc.value, BlackTree(tl.key, tl.value, tl.left, bc.left), BlackTree(tr.key, tr.value, bc.right, tr.right)) } else { - sys.error("unmatched tree on append: " + tl + ", " + tr) + balLeft(tl.key, tl.value, tl.left, BlackTree(tr.key, tr.value, bc, tr.right)) } - - val cmp = ordering.compare(k, key) - if (cmp < 0) delLeft - else if (cmp > 0) delRight - else append(left, right) + } else if (isRedTree(tr)) { + RedTree(tr.key, tr.value, append(tl, tr.left), tr.right) + } else if (isRedTree(tl)) { + RedTree(tl.key, tl.value, tl.left, append(tl.right, tr)) + } else { + sys.error("unmatched tree on append: " + tl + ", " + tr) } - def smallest: NonEmpty[A, B] = if (left eq Empty.Instance) this else left.smallest - def greatest: NonEmpty[A, B] = if (right eq Empty.Instance) this else right.greatest + val cmp = ordering.compare(k, tree.key) + if (cmp < 0) delLeft + else if (cmp > 0) delRight + else append(tree.left, tree.right) + } - def iterator: Iterator[(A, B)] = new TreeIterator(this) - def keyIterator: Iterator[A] = new TreeKeyIterator(this) + private[this] def rng[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = { + if (tree eq null) return null + if (from == None && until == None) return tree + if (from != None && ordering.lt(tree.key, from.get)) return rng(tree.right, from, until); + if (until != None && ordering.lteq(until.get, tree.key)) return rng(tree.left, from, until); + val newLeft = rng(tree.left, from, None) + val newRight = rng(tree.right, None, until) + if ((newLeft eq tree.left) && (newRight eq tree.right)) tree + else if (newLeft eq null) upd(newRight, tree.key, tree.value); + else if (newRight eq null) upd(newLeft, tree.key, tree.value); + else rebalance(tree, newLeft, newRight) + } - override def foreach[U](f: ((A, B)) => U) { - if (left ne Empty.Instance) left foreach f - f((key, value)) - if (right ne Empty.Instance) right foreach f + // The zipper returned might have been traversed left-most (always the left child) + // or right-most (always the right child). Left trees are traversed right-most, + // and right trees are traversed leftmost. + + // Returns the zipper for the side with deepest black nodes depth, a flag + // indicating whether the trees were unbalanced at all, and a flag indicating + // whether the zipper was traversed left-most or right-most. + + // If the trees were balanced, returns an empty zipper + private[this] def compareDepth[A, B](left: Tree[A, B], right: Tree[A, B]): (List[Tree[A, B]], Boolean, Boolean, Int) = { + // Once a side is found to be deeper, unzip it to the bottom + def unzip(zipper: List[Tree[A, B]], leftMost: Boolean): List[Tree[A, B]] = { + val next = if (leftMost) zipper.head.left else zipper.head.right + next match { + case null => zipper + case node => unzip(node :: zipper, leftMost) + } } - override def foreachKey[U](f: A => U) { - if (left ne Empty.Instance) left foreachKey f - f(key) - if (right ne Empty.Instance) right foreachKey f + // Unzip left tree on the rightmost side and right tree on the leftmost side until one is + // found to be deeper, or the bottom is reached + def unzipBoth(left: Tree[A, B], + right: Tree[A, B], + leftZipper: List[Tree[A, B]], + rightZipper: List[Tree[A, B]], + smallerDepth: Int): (List[Tree[A, B]], Boolean, Boolean, Int) = { + if (isBlackTree(left) && isBlackTree(right)) { + unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth + 1) + } else if (isRedTree(left) && isRedTree(right)) { + unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth) + } else if (isRedTree(right)) { + unzipBoth(left, right.left, leftZipper, right :: rightZipper, smallerDepth) + } else if (isRedTree(left)) { + unzipBoth(left.right, right, left :: leftZipper, rightZipper, smallerDepth) + } else if ((left eq null) && (right eq null)) { + (Nil, true, false, smallerDepth) + } else if ((left eq null) && isBlackTree(right)) { + val leftMost = true + (unzip(right :: rightZipper, leftMost), false, leftMost, smallerDepth) + } else if (isBlackTree(left) && (right eq null)) { + val leftMost = false + (unzip(left :: leftZipper, leftMost), false, leftMost, smallerDepth) + } else { + sys.error("unmatched trees in unzip: " + left + ", " + right) + } } - - override def rng(from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = { - if (from == None && until == None) return this - if (from != None && ordering.lt(key, from.get)) return right.rng(from, until); - if (until != None && ordering.lteq(until.get, key)) return left.rng(from, until); - val newLeft = left.rng(from, None) - val newRight = right.rng(None, until) - if ((newLeft eq left) && (newRight eq right)) this - else if (newLeft eq Empty.Instance) newRight.upd(key, value); - else if (newRight eq Empty.Instance) newLeft.upd(key, value); - else rebalance(newLeft, newRight) + unzipBoth(left, right, Nil, Nil, 0) + } + private[this] def rebalance[A, B](tree: Tree[A, B], newLeft: Tree[A, B], newRight: Tree[A, B]) = { + // This is like drop(n-1), but only counting black nodes + def findDepth(zipper: List[Tree[A, B]], depth: Int): List[Tree[A, B]] = zipper match { + case head :: tail if isBlackTree(head) => + if (depth == 1) zipper else findDepth(tail, depth - 1) + case _ :: tail => findDepth(tail, depth) + case Nil => sys.error("Defect: unexpected empty zipper while computing range") } - // The zipper returned might have been traversed left-most (always the left child) - // or right-most (always the right child). Left trees are traversed right-most, - // and right trees are traversed leftmost. - - // Returns the zipper for the side with deepest black nodes depth, a flag - // indicating whether the trees were unbalanced at all, and a flag indicating - // whether the zipper was traversed left-most or right-most. - - // If the trees were balanced, returns an empty zipper - private[this] def compareDepth(left: Tree[A, B], right: Tree[A, B]): (List[NonEmpty[A, B]], Boolean, Boolean, Int) = { - // Once a side is found to be deeper, unzip it to the bottom - def unzip(zipper: List[NonEmpty[A, B]], leftMost: Boolean): List[NonEmpty[A, B]] = { - val next = if (leftMost) zipper.head.left else zipper.head.right - next match { - case node: NonEmpty[_, _] => unzip(node :: zipper, leftMost) - case _ => zipper - } + // Blackening the smaller tree avoids balancing problems on union; + // this can't be done later, though, or it would change the result of compareDepth + val blkNewLeft = blacken(newLeft) + val blkNewRight = blacken(newRight) + val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) + + if (levelled) { + BlackTree(tree.key, tree.value, blkNewLeft, blkNewRight) + } else { + val zipFrom = findDepth(zipper, smallerDepth) + val union = if (leftMost) { + RedTree(tree.key, tree.value, blkNewLeft, zipFrom.head) + } else { + RedTree(tree.key, tree.value, zipFrom.head, blkNewRight) } - - // Unzip left tree on the rightmost side and right tree on the leftmost side until one is - // found to be deeper, or the bottom is reached - def unzipBoth(left: Tree[A, B], - right: Tree[A, B], - leftZipper: List[NonEmpty[A, B]], - rightZipper: List[NonEmpty[A, B]], - smallerDepth: Int): (List[NonEmpty[A, B]], Boolean, Boolean, Int) = { - lazy val l = left.asInstanceOf[NonEmpty[A, B]] - lazy val r = right.asInstanceOf[NonEmpty[A, B]] - if (isBlackTree(left) && isBlackTree(right)) { - unzipBoth(l.right, r.left, l :: leftZipper, r :: rightZipper, smallerDepth + 1) - } else if (isRedTree(left) && isRedTree(right)) { - unzipBoth(l.right, r.left, l :: leftZipper, r :: rightZipper, smallerDepth) - } else if (isRedTree(right)) { - unzipBoth(left, r.left, leftZipper, r :: rightZipper, smallerDepth) - } else if (isRedTree(left)) { - unzipBoth(l.right, right, l :: leftZipper, rightZipper, smallerDepth) - } else if ((left eq Empty.Instance) && (right eq Empty.Instance)) { - (Nil, true, false, smallerDepth) - } else if ((left eq Empty.Instance) && isBlackTree(right)) { - val leftMost = true - (unzip(r :: rightZipper, leftMost), false, leftMost, smallerDepth) - } else if (isBlackTree(left) && (right eq Empty.Instance)) { - val leftMost = false - (unzip(l :: leftZipper, leftMost), false, leftMost, smallerDepth) - } else { - sys.error("unmatched trees in unzip: " + left + ", " + right) - } + val zippedTree = zipFrom.tail.foldLeft(union: Tree[A, B]) { (tree, node) => + if (leftMost) + balanceLeft(node.isBlack, node.key, node.value, tree, node.right) + else + balanceRight(node.isBlack, node.key, node.value, node.left, tree) } - unzipBoth(left, right, Nil, Nil, 0) + zippedTree } + } - private[this] def rebalance(newLeft: Tree[A, B], newRight: Tree[A, B]) = { - // This is like drop(n-1), but only counting black nodes - def findDepth(zipper: List[NonEmpty[A, B]], depth: Int): List[NonEmpty[A, B]] = zipper match { - case head :: tail if isBlackTree(head) => - if (depth == 1) zipper else findDepth(tail, depth - 1) - case _ :: tail => findDepth(tail, depth) - case Nil => sys.error("Defect: unexpected empty zipper while computing range") - } - - // Blackening the smaller tree avoids balancing problems on union; - // this can't be done later, though, or it would change the result of compareDepth - val blkNewLeft = blacken(newLeft) - val blkNewRight = blacken(newRight) - val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) - - if (levelled) { - BlackTree(key, value, blkNewLeft, blkNewRight) - } else { - val zipFrom = findDepth(zipper, smallerDepth) - val union = if (leftMost) { - RedTree(key, value, blkNewLeft, zipFrom.head) - } else { - RedTree(key, value, zipFrom.head, blkNewRight) - } - val zippedTree = zipFrom.tail.foldLeft(union: Tree[A, B]) { (tree, node) => - if (leftMost) - balanceLeft(node.isBlack, node.key, node.value, tree, node.right) - else - balanceRight(node.isBlack, node.key, node.value, node.left, tree) - } - zippedTree - } - } - def first = if (left eq Empty.Instance) key else left.first - def last = if (right eq Empty.Instance) key else right.last - val count = 1 + left.count + right.count - protected[immutable] def nth(n: Int) = { - val count = left.count + sealed abstract class Tree[A, +B]( + @(inline @getter) final val key: A, + @(inline @getter) final val value: B, + @(inline @getter) final val left: Tree[A, B], + @(inline @getter) final val right: Tree[A, B]) + extends Serializable { + @(inline @getter) final val count: Int = 1 + RedBlack.count(left) + RedBlack.count(right) + def isBlack: Boolean + def nth(n: Int): Tree[A, B] = { + val count = RedBlack.count(left) if (n < count) left.nth(n) else if (n > count) right.nth(n - count - 1) else this } + def black: Tree[A, B] + def red: Tree[A, B] } - object Empty { - def empty[A]: Tree[A, Nothing] = Instance.asInstanceOf[Tree[A, Nothing]] - final val Instance: Tree[_ >: Nothing, Nothing] = Empty[Nothing]() - } - final case class Empty[A] private () extends Tree[A, Nothing] { - def key = throw new NoSuchElementException("empty map") - def value = throw new NoSuchElementException("empty map") - def left = this - def right = this - def isEmpty = true - def isBlack = true - def lookup(k: A)(implicit ordering: Ordering[A]): Tree[A, Nothing] = this - def upd[B](k: A, v: B)(implicit ordering: Ordering[A]): Tree[A, B] = RedTree(k, v, this, this) - def del(k: A)(implicit ordering: Ordering[A]): Tree[A, Nothing] = this - def smallest: NonEmpty[A, Nothing] = throw new NoSuchElementException("empty map") - def greatest: NonEmpty[A, Nothing] = throw new NoSuchElementException("empty map") - def iterator: Iterator[(A, Nothing)] = Iterator.empty - def keyIterator: Iterator[A] = Iterator.empty - - override def foreach[U](f: ((A, Nothing)) => U) {} - override def foreachKey[U](f: A => U) {} - - def rng(from: Option[A], until: Option[A])(implicit ordering: Ordering[A]) = this - def first = throw new NoSuchElementException("empty map") - def last = throw new NoSuchElementException("empty map") - def count = 0 - protected[immutable] def nth(n: Int) = throw new NoSuchElementException("empty map") - override def red = sys.error("cannot make leaf red") - - override def toString() = "Empty" - - private def readResolve() = Empty.empty - } final class RedTree[A, +B](key: A, - value: B, - left: Tree[A, B], - right: Tree[A, B]) extends NonEmpty[A, B](key, value, left, right) { - def isBlack = false + value: B, + left: Tree[A, B], + right: Tree[A, B]) extends Tree[A, B](key, value, left, right) { + override def isBlack = false override def black = BlackTree(key, value, left, right) override def red = this + override def toString = "RedTree(" + key + ", " + value + ", " + left + ", " + right + ")" } object RedTree { def apply[A, B](key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = new RedTree(key, value, left, right) def unapply[A, B](t: RedTree[A, B]) = Some((t.key, t.value, t.left, t.right)) } final class BlackTree[A, +B](key: A, - value: B, - left: Tree[A, B], - right: Tree[A, B]) extends NonEmpty[A, B](key, value, left, right) { - def isBlack = true + value: B, + left: Tree[A, B], + right: Tree[A, B]) extends Tree[A, B](key, value, left, right) { + override def isBlack = true + override def black = this override def red = RedTree(key, value, left, right) + override def toString = "BlackTree(" + key + ", " + value + ", " + left + ", " + right + ")" } object BlackTree { def apply[A, B](key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = new BlackTree(key, value, left, right) def unapply[A, B](t: BlackTree[A, B]) = Some((t.key, t.value, t.left, t.right)) } - private[this] class TreeIterator[A, B](tree: NonEmpty[A, B]) extends Iterator[(A, B)] { - override def hasNext: Boolean = next ne Empty.Instance + private[this] class TreeIterator[A, B](tree: Tree[A, B]) extends Iterator[(A, B)] { + override def hasNext: Boolean = next ne null override def next: (A, B) = next match { - case Empty.Instance => + case null => throw new NoSuchElementException("next on empty iterator") - case tree: NonEmpty[A, B] => + case tree => addLeftMostBranchToPath(tree.right) - next = if (path.isEmpty) Empty.empty else path.pop() + next = if (path.isEmpty) null else path.pop() (tree.key, tree.value) } @annotation.tailrec private[this] def addLeftMostBranchToPath(tree: Tree[A, B]) { - tree match { - case Empty.Instance => - case tree: NonEmpty[A, B] => - path.push(tree) - addLeftMostBranchToPath(tree.left) + if (tree ne null) { + path.push(tree) + addLeftMostBranchToPath(tree.left) } } - private[this] val path = mutable.ArrayStack.empty[NonEmpty[A, B]] + private[this] val path = mutable.ArrayStack.empty[Tree[A, B]] addLeftMostBranchToPath(tree) private[this] var next: Tree[A, B] = path.pop() } - private[this] class TreeKeyIterator[A](tree: NonEmpty[A, _]) extends Iterator[A] { - override def hasNext: Boolean = next ne Empty.Instance + private[this] class TreeKeyIterator[A](tree: Tree[A, _]) extends Iterator[A] { + override def hasNext: Boolean = next ne null override def next: A = next match { - case Empty.Instance => + case null => throw new NoSuchElementException("next on empty iterator") - case tree: NonEmpty[A, _] => + case tree => addLeftMostBranchToPath(tree.right) - next = if (path.isEmpty) Empty.empty else path.pop() + next = if (path.isEmpty) null else path.pop() tree.key } @annotation.tailrec private[this] def addLeftMostBranchToPath(tree: Tree[A, _]) { - tree match { - case Empty.Instance => - case tree: NonEmpty[A, _] => - path.push(tree) - addLeftMostBranchToPath(tree.left) + if (tree ne null) { + path.push(tree) + addLeftMostBranchToPath(tree.left) } } - private[this] val path = mutable.ArrayStack.empty[NonEmpty[A, _]] + private[this] val path = mutable.ArrayStack.empty[Tree[A, _]] addLeftMostBranchToPath(tree) private[this] var next: Tree[A, _] = path.pop() } diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 48a0bc3d44..45e936444f 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -51,39 +51,39 @@ class TreeMap[A, +B] private (tree: RedBlack.Tree[A, B])(implicit val ordering: with MapLike[A, B, TreeMap[A, B]] with Serializable { - import RedBlack._ + import immutable.{RedBlack => RB} def isSmaller(x: A, y: A) = ordering.lt(x, y) override protected[this] def newBuilder : Builder[(A, B), TreeMap[A, B]] = TreeMap.newBuilder[A, B] - override def size = tree.count + override def size = RB.count(tree) - def this()(implicit ordering: Ordering[A]) = this(RedBlack.Empty.empty)(ordering) + def this()(implicit ordering: Ordering[A]) = this(null)(ordering) override def rangeImpl(from : Option[A], until : Option[A]): TreeMap[A,B] = { - val ntree = tree.range(from,until) + val ntree = RB.range(tree, from,until) new TreeMap[A,B](ntree) } - override def firstKey = tree.first - override def lastKey = tree.last + override def firstKey = RB.smallest(tree).key + override def lastKey = RB.greatest(tree).key override def compare(k0: A, k1: A): Int = ordering.compare(k0, k1) override def head = { - val smallest = tree.smallest + val smallest = RB.smallest(tree) (smallest.key, smallest.value) } - override def headOption = if (tree.isEmpty) None else Some(head) + override def headOption = if (RB.isEmpty(tree)) None else Some(head) override def last = { - val greatest = tree.greatest + val greatest = RB.greatest(tree) (greatest.key, greatest.value) } - override def lastOption = if (tree.isEmpty) None else Some(last) + override def lastOption = if (RB.isEmpty(tree)) None else Some(last) - override def tail = new TreeMap(tree.delete(firstKey)) - override def init = new TreeMap(tree.delete(lastKey)) + override def tail = new TreeMap(RB.delete(tree, firstKey)) + override def init = new TreeMap(RB.delete(tree, lastKey)) override def drop(n: Int) = { if (n <= 0) this @@ -134,7 +134,7 @@ class TreeMap[A, +B] private (tree: RedBlack.Tree[A, B])(implicit val ordering: * @param value the value to be associated with `key` * @return a new $coll with the updated binding */ - override def updated [B1 >: B](key: A, value: B1): TreeMap[A, B1] = new TreeMap(tree.update(key, value)) + override def updated [B1 >: B](key: A, value: B1): TreeMap[A, B1] = new TreeMap(RB.update(tree, key, value)) /** Add a key/value pair to this map. * @tparam B1 type of the value of the new binding, a supertype of `B` @@ -175,13 +175,13 @@ class TreeMap[A, +B] private (tree: RedBlack.Tree[A, B])(implicit val ordering: * @return a new $coll with the inserted binding, if it wasn't present in the map */ def insert [B1 >: B](key: A, value: B1): TreeMap[A, B1] = { - assert(tree.lookup(key).isEmpty) - new TreeMap(tree.update(key, value)) + assert(!RB.contains(tree, key)) + new TreeMap(RB.update(tree, key, value)) } def - (key:A): TreeMap[A, B] = - if (tree.lookup(key).isEmpty) this - else new TreeMap(tree.delete(key)) + if (!RB.contains(tree, key)) this + else new TreeMap(RB.delete(tree, key)) /** Check if this map maps `key` to a value and return the * value if it exists. @@ -189,21 +189,19 @@ class TreeMap[A, +B] private (tree: RedBlack.Tree[A, B])(implicit val ordering: * @param key the key of the mapping of interest * @return the value of the mapping, if it exists */ - override def get(key: A): Option[B] = lookup(tree, key) match { - case n: NonEmpty[_, _] => Some(n.value) - case _ => None - } + override def get(key: A): Option[B] = RB.get(tree, key) /** Creates a new iterator over all elements contained in this * object. * * @return the new iterator */ - def iterator: Iterator[(A, B)] = tree.iterator + def iterator: Iterator[(A, B)] = RB.iterator(tree) - override def toStream: Stream[(A, B)] = tree.iterator.toStream + override def contains(key: A): Boolean = RB.contains(tree, key) + override def isDefinedAt(key: A): Boolean = RB.contains(tree, key) - override def foreach[U](f : ((A,B)) => U) = tree foreach f + override def foreach[U](f : ((A,B)) => U) = RB.foreach(tree, f) } diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index 74c63d0eb5..00ebeab868 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -50,19 +50,19 @@ object TreeSet extends ImmutableSortedSetFactory[TreeSet] { class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: Ordering[A]) extends SortedSet[A] with SortedSetLike[A, TreeSet[A]] with Serializable { - import RedBlack._ + import immutable.{RedBlack => RB} override def stringPrefix = "TreeSet" - override def size = tree.count + override def size = RB.count(tree) - override def head = tree.smallest.key - override def headOption = if (tree.isEmpty) None else Some(head) - override def last = tree.greatest.key - override def lastOption = if (tree.isEmpty) None else Some(last) + override def head = RB.smallest(tree).key + override def headOption = if (RB.isEmpty(tree)) None else Some(head) + override def last = RB.greatest(tree).key + override def lastOption = if (RB.isEmpty(tree)) None else Some(last) - override def tail = new TreeSet(tree.delete(firstKey)) - override def init = new TreeSet(tree.delete(lastKey)) + override def tail = new TreeSet(RB.delete(tree, firstKey)) + override def init = new TreeSet(RB.delete(tree, lastKey)) override def drop(n: Int) = { if (n <= 0) this @@ -102,7 +102,7 @@ class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: O def isSmaller(x: A, y: A) = compare(x,y) < 0 - def this()(implicit ordering: Ordering[A]) = this(RedBlack.Empty.empty)(ordering) + def this()(implicit ordering: Ordering[A]) = this(null)(ordering) private def newSet(t: RedBlack.Tree[A, Unit]) = new TreeSet[A](t) @@ -115,7 +115,7 @@ class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: O * @param elem a new element to add. * @return a new $coll containing `elem` and all the elements of this $coll. */ - def + (elem: A): TreeSet[A] = newSet(tree.update(elem, ())) + def + (elem: A): TreeSet[A] = newSet(RB.update(tree, elem, ())) /** A new `TreeSet` with the entry added is returned, * assuming that elem is not in the TreeSet. @@ -124,8 +124,8 @@ class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: O * @return a new $coll containing `elem` and all the elements of this $coll. */ def insert(elem: A): TreeSet[A] = { - assert(tree.lookup(elem).isEmpty) - newSet(tree.update(elem, ())) + assert(!RB.contains(tree, elem)) + newSet(RB.update(tree, elem, ())) } /** Creates a new `TreeSet` with the entry removed. @@ -134,31 +134,29 @@ class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: O * @return a new $coll containing all the elements of this $coll except `elem`. */ def - (elem:A): TreeSet[A] = - if (tree.lookup(elem).isEmpty) this - else newSet(tree delete elem) + if (!RB.contains(tree, elem)) this + else newSet(RB.delete(tree, elem)) /** Checks if this set contains element `elem`. * * @param elem the element to check for membership. * @return true, iff `elem` is contained in this set. */ - def contains(elem: A): Boolean = !lookup(tree, elem).isEmpty + def contains(elem: A): Boolean = RB.contains(tree, elem) /** Creates a new iterator over all elements contained in this * object. * * @return the new iterator */ - def iterator: Iterator[A] = tree.keyIterator + def iterator: Iterator[A] = RB.keyIterator(tree) - override def toStream: Stream[A] = tree.keyIterator.toStream - - override def foreach[U](f: A => U) = tree foreachKey f + override def foreach[U](f: A => U) = RB.foreachKey(tree, f) override def rangeImpl(from: Option[A], until: Option[A]): TreeSet[A] = { - val tree = this.tree.range(from, until) - newSet(tree) + val ntree = RB.range(tree, from, until) + newSet(ntree) } - override def firstKey = tree.first - override def lastKey = tree.last + override def firstKey = head + override def lastKey = last } diff --git a/test/files/scalacheck/redblack.scala b/test/files/scalacheck/redblack.scala index 78fb645ce8..5c52a27e38 100644 --- a/test/files/scalacheck/redblack.scala +++ b/test/files/scalacheck/redblack.scala @@ -8,7 +8,7 @@ Properties of a Red & Black Tree: A node is either red or black. The root is black. (This rule is used in some definitions and not others. Since the -root can always be changed from red to black but not necessarily vice-versa this +root can always be changed from red to black but not necessarily vice-versa this rule has little effect on analysis.) All leaves are black. Both children of every red node are black. @@ -21,17 +21,17 @@ abstract class RedBlackTest extends Properties("RedBlack") { def maximumSize = 5 import RedBlack._ - - def nodeAt[A](tree: Tree[String, A], n: Int): Option[(String, A)] = if (n < tree.iterator.size && n >= 0) - Some(tree.iterator.drop(n).next) + + def nodeAt[A](tree: Tree[String, A], n: Int): Option[(String, A)] = if (n < iterator(tree).size && n >= 0) + Some(iterator(tree).drop(n).next) else None - - def treeContains[A](tree: Tree[String, A], key: String) = tree.iterator.map(_._1) contains key - - def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[String, Int]] = + + def treeContains[A](tree: Tree[String, A], key: String) = iterator(tree).map(_._1) contains key + + def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[String, Int]] = if (level == 0) { - value(Empty.empty) + value(null) } else { for { oddOrEven <- choose(0, 2) @@ -41,7 +41,7 @@ abstract class RedBlackTest extends Properties("RedBlack") { left <- mkTree(nextLevel, !isRed, label + "L") right <- mkTree(nextLevel, !isRed, label + "R") } yield { - if (isRed) + if (isRed) RedTree(label + "N", 0, left, right) else BlackTree(label + "N", 0, left, right) @@ -52,11 +52,11 @@ abstract class RedBlackTest extends Properties("RedBlack") { depth <- choose(minimumSize, maximumSize + 1) tree <- mkTree(depth) } yield tree - + type ModifyParm def genParm(tree: Tree[String, Int]): Gen[ModifyParm] def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] - + def genInput: Gen[(Tree[String, Int], ModifyParm, Tree[String, Int])] = for { tree <- genTree parm <- genParm(tree) @@ -65,41 +65,41 @@ abstract class RedBlackTest extends Properties("RedBlack") { trait RedBlackInvariants { self: RedBlackTest => - + import RedBlack._ - - def rootIsBlack[A](t: Tree[String, A]) = t.isBlack - + + def rootIsBlack[A](t: Tree[String, A]) = isBlack(t) + def areAllLeavesBlack[A](t: Tree[String, A]): Boolean = t match { - case Empty.Instance => t.isBlack - case ne: NonEmpty[_, _] => List(ne.left, ne.right) forall areAllLeavesBlack + case null => isBlack(t) + case ne => List(ne.left, ne.right) forall areAllLeavesBlack } - + def areRedNodeChildrenBlack[A](t: Tree[String, A]): Boolean = t match { - case RedTree(_, _, left, right) => List(left, right) forall (t => t.isBlack && areRedNodeChildrenBlack(t)) + case RedTree(_, _, left, right) => List(left, right) forall (t => isBlack(t) && areRedNodeChildrenBlack(t)) case BlackTree(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack - case Empty.Instance => true + case null => true } - + def blackNodesToLeaves[A](t: Tree[String, A]): List[Int] = t match { - case Empty.Instance => List(1) + case null => List(1) case BlackTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) case RedTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves } - + def areBlackNodesToLeavesEqual[A](t: Tree[String, A]): Boolean = t match { - case Empty.Instance => true - case ne: NonEmpty[_, _] => + case null => true + case ne => ( - blackNodesToLeaves(ne).distinct.size == 1 - && areBlackNodesToLeavesEqual(ne.left) + blackNodesToLeaves(ne).distinct.size == 1 + && areBlackNodesToLeavesEqual(ne.left) && areBlackNodesToLeavesEqual(ne.right) ) } - - def orderIsPreserved[A](t: Tree[String, A]): Boolean = - t.iterator zip t.iterator.drop(1) forall { case (x, y) => x._1 < y._1 } - + + def orderIsPreserved[A](t: Tree[String, A]): Boolean = + iterator(t) zip iterator(t).drop(1) forall { case (x, y) => x._1 < y._1 } + def setup(invariant: Tree[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => invariant(newTree) } @@ -113,10 +113,10 @@ trait RedBlackInvariants { object TestInsert extends RedBlackTest with RedBlackInvariants { import RedBlack._ - + override type ModifyParm = Int - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size + 1) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = tree update (generateKey(tree, parm), 0) + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0) def generateKey(tree: Tree[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match { case Some((key, _)) => key.init.mkString + "MN" @@ -133,18 +133,18 @@ object TestInsert extends RedBlackTest with RedBlackInvariants { object TestModify extends RedBlackTest { import RedBlack._ - + def newValue = 1 override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { - case (key, _) => tree update (key, newValue) + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { + case (key, _) => update(tree, key, newValue) } getOrElse tree property("update modifies values") = forAll(genInput) { case (tree, parm, newTree) => nodeAt(tree,parm) forall { case (key, _) => - newTree.iterator contains (key, newValue) + iterator(newTree) contains (key, newValue) } } } @@ -154,11 +154,11 @@ object TestDelete extends RedBlackTest with RedBlackInvariants { override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { - case (key, _) => tree delete key + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { + case (key, _) => delete(tree, key) } getOrElse tree - + property("delete removes elements") = forAll(genInput) { case (tree, parm, newTree) => nodeAt(tree, parm) forall { case (key, _) => !treeContains(newTree, key) @@ -168,37 +168,37 @@ object TestDelete extends RedBlackTest with RedBlackInvariants { object TestRange extends RedBlackTest with RedBlackInvariants { import RedBlack._ - + override type ModifyParm = (Option[Int], Option[Int]) override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { - from <- choose(0, tree.iterator.size) - to <- choose(0, tree.iterator.size) suchThat (from <=) + from <- choose(0, iterator(tree).size) + to <- choose(0, iterator(tree).size) suchThat (from <=) optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug } yield (optionalFrom, optionalTo) - + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = { val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - tree range (from, to) + range(tree, from, to) } - + property("range boundaries respected") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - ("lower boundary" |: (from forall ( key => newTree.iterator.map(_._1) forall (key <=)))) && - ("upper boundary" |: (to forall ( key => newTree.iterator.map(_._1) forall (key >)))) + ("lower boundary" |: (from forall ( key => iterator(newTree).map(_._1) forall (key <=)))) && + ("upper boundary" |: (to forall ( key => iterator(newTree).map(_._1) forall (key >)))) } - + property("range returns all elements") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - val filteredTree = (tree.iterator - .map(_._1) + val filteredTree = (iterator(tree) + .map(_._1) .filter(key => from forall (key >=)) .filter(key => to forall (key <)) .toList) - filteredTree == newTree.iterator.map(_._1).toList + filteredTree == iterator(newTree).map(_._1).toList } } } -- cgit v1.2.3 From 72ec0ac869a29fca9ea0d45a3f70f1e9e1babaaf Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Wed, 4 Jan 2012 17:10:20 +0100 Subject: Optimize foreach and iterators. --- .../scala/collection/immutable/RedBlack.scala | 108 +++++++++++++-------- .../scala/collection/immutable/TreeMap.scala | 5 +- .../scala/collection/immutable/TreeSet.scala | 2 +- test/files/scalacheck/treemap.scala | 16 +++ test/files/scalacheck/treeset.scala | 16 +++ 5 files changed, 103 insertions(+), 44 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 2537d043fd..6af6b6ef03 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -11,6 +11,7 @@ package scala.collection package immutable +import annotation.tailrec import annotation.meta.getter /** An object containing the RedBlack tree implementation used by for `TreeMaps` and `TreeSets`. @@ -37,7 +38,7 @@ object RedBlack { case tree => Some(tree.value) } - @annotation.tailrec + @tailrec def lookup[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else { val cmp = ordering.compare(x, tree.key) if (cmp < 0) lookup(tree.left, x) @@ -64,18 +65,19 @@ object RedBlack { } def foreach[A, B, U](tree: Tree[A, B], f: ((A, B)) => U): Unit = if (tree ne null) { - foreach(tree.left, f) + if (tree.left ne null) foreach(tree.left, f) f((tree.key, tree.value)) - foreach(tree.right, f) + if (tree.right ne null) foreach(tree.right, f) } def foreachKey[A, U](tree: Tree[A, _], f: A => U): Unit = if (tree ne null) { - foreachKey(tree.left, f) + if (tree.left ne null) foreachKey(tree.left, f) f(tree.key) - foreachKey(tree.right, f) + if (tree.right ne null) foreachKey(tree.right, f) } - def iterator[A, B](tree: Tree[A, B]): Iterator[(A, B)] = if (tree eq null) Iterator.empty else new TreeIterator(tree) - def keyIterator[A, _](tree: Tree[A, _]): Iterator[A] = if (tree eq null) Iterator.empty else new TreeKeyIterator(tree) + def iterator[A, B](tree: Tree[A, B]): Iterator[(A, B)] = new EntriesIterator(tree) + def keysIterator[A, _](tree: Tree[A, _]): Iterator[A] = new KeysIterator(tree) + def valuesIterator[_, B](tree: Tree[_, B]): Iterator[B] = new ValuesIterator(tree) private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[A, B1], d: Tree[A, B1]): Tree[A, B1] = { if (isRedTree(l) && isRedTree(l.left)) @@ -283,7 +285,7 @@ object RedBlack { @(inline @getter) final val left: Tree[A, B], @(inline @getter) final val right: Tree[A, B]) extends Serializable { - @(inline @getter) final val count: Int = 1 + RedBlack.count(left) + RedBlack.count(right) + final val count: Int = 1 + RedBlack.count(left) + RedBlack.count(right) def isBlack: Boolean def nth(n: Int): Tree[A, B] = { val count = RedBlack.count(left) @@ -322,53 +324,75 @@ object RedBlack { def unapply[A, B](t: BlackTree[A, B]) = Some((t.key, t.value, t.left, t.right)) } - private[this] class TreeIterator[A, B](tree: Tree[A, B]) extends Iterator[(A, B)] { + private[this] abstract class TreeIterator[A, B, R](tree: Tree[A, B]) extends Iterator[R] { + protected[this] def nextResult(tree: Tree[A, B]): R + override def hasNext: Boolean = next ne null - override def next: (A, B) = next match { + override def next: R = next match { case null => throw new NoSuchElementException("next on empty iterator") case tree => - addLeftMostBranchToPath(tree.right) - next = if (path.isEmpty) null else path.pop() - (tree.key, tree.value) + next = findNext(tree.right) + nextResult(tree) } - @annotation.tailrec - private[this] def addLeftMostBranchToPath(tree: Tree[A, B]) { - if (tree ne null) { - path.push(tree) - addLeftMostBranchToPath(tree.left) + @tailrec + private[this] def findNext(tree: Tree[A, B]): Tree[A, B] = { + if (tree eq null) popPath() + else if (tree.left eq null) tree + else { + pushPath(tree) + findNext(tree.left) } } - private[this] val path = mutable.ArrayStack.empty[Tree[A, B]] - addLeftMostBranchToPath(tree) - private[this] var next: Tree[A, B] = path.pop() - } - - private[this] class TreeKeyIterator[A](tree: Tree[A, _]) extends Iterator[A] { - override def hasNext: Boolean = next ne null - - override def next: A = next match { - case null => - throw new NoSuchElementException("next on empty iterator") - case tree => - addLeftMostBranchToPath(tree.right) - next = if (path.isEmpty) null else path.pop() - tree.key + private[this] def pushPath(tree: Tree[A, B]) { + try { + path(index) = tree + index += 1 + } catch { + case _: ArrayIndexOutOfBoundsException => + // Either the tree became unbalanced or we calculated the maximum height incorrectly. + // To avoid crashing the iterator we expand the path array. Obviously this should never + // happen... + // + // An exception handler is used instead of an if-condition to optimize the normal path. + assert(index >= path.length) + path :+= null + pushPath(tree) + } + } + private[this] def popPath(): Tree[A, B] = if (index == 0) null else { + index -= 1 + path(index) } - @annotation.tailrec - private[this] def addLeftMostBranchToPath(tree: Tree[A, _]) { - if (tree ne null) { - path.push(tree) - addLeftMostBranchToPath(tree.left) - } + private[this] var path = if (tree eq null) null else { + /* + * According to "Ralf Hinze. Constructing red-black trees" [http://www.cs.ox.ac.uk/ralf.hinze/publications/#P5] + * the maximum height of a red-black tree is 2*log_2(n + 2) - 2. + * + * According to {@see Integer#numberOfLeadingZeros} ceil(log_2(n)) = (32 - Integer.numberOfLeadingZeros(n - 1)) + * + * We also don't store the deepest nodes in the path so the maximum path length is further reduced by one. + */ + val maximumHeight = 2 * (32 - Integer.numberOfLeadingZeros(tree.count + 2 - 1)) - 2 - 1 + new Array[Tree[A, B]](maximumHeight) } + private[this] var index = 0 + private[this] var next: Tree[A, B] = findNext(tree) + } + + private[this] class EntriesIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, (A, B)](tree) { + override def nextResult(tree: Tree[A, B]) = (tree.key, tree.value) + } + + private[this] class KeysIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, A](tree) { + override def nextResult(tree: Tree[A, B]) = tree.key + } - private[this] val path = mutable.ArrayStack.empty[Tree[A, _]] - addLeftMostBranchToPath(tree) - private[this] var next: Tree[A, _] = path.pop() + private[this] class ValuesIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, B](tree) { + override def nextResult(tree: Tree[A, B]) = tree.value } } diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 45e936444f..6e8cf625f4 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -196,7 +196,10 @@ class TreeMap[A, +B] private (tree: RedBlack.Tree[A, B])(implicit val ordering: * * @return the new iterator */ - def iterator: Iterator[(A, B)] = RB.iterator(tree) + override def iterator: Iterator[(A, B)] = RB.iterator(tree) + + override def keysIterator: Iterator[A] = RB.keysIterator(tree) + override def valuesIterator: Iterator[B] = RB.valuesIterator(tree) override def contains(key: A): Boolean = RB.contains(tree, key) override def isDefinedAt(key: A): Boolean = RB.contains(tree, key) diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index 00ebeab868..7c27e9f5b0 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -149,7 +149,7 @@ class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: O * * @return the new iterator */ - def iterator: Iterator[A] = RB.keyIterator(tree) + def iterator: Iterator[A] = RB.keysIterator(tree) override def foreach[U](f: A => U) = RB.foreachKey(tree, f) diff --git a/test/files/scalacheck/treemap.scala b/test/files/scalacheck/treemap.scala index 43d307600d..9970bb01aa 100644 --- a/test/files/scalacheck/treemap.scala +++ b/test/files/scalacheck/treemap.scala @@ -22,6 +22,22 @@ object Test extends Properties("TreeMap") { consistent } + property("worst-case tree height is iterable") = forAll(choose(0, 10), arbitrary[Boolean]) { (n: Int, even: Boolean) => + /* + * According to "Ralf Hinze. Constructing red-black trees" [http://www.cs.ox.ac.uk/ralf.hinze/publications/#P5] + * you can construct a skinny tree of height 2n by inserting the elements [1 .. 2^(n+1) - 2] and a tree of height + * 2n+1 by inserting the elements [1 .. 3 * 2^n - 2], both in reverse order. + * + * Since we allocate a fixed size buffer in the iterator (based on the tree size) we need to ensure + * it is big enough for these worst-case trees. + */ + val highest = if (even) (1 << (n+1)) - 2 else 3*(1 << n) - 2 + val values = (1 to highest).reverse + val subject = TreeMap(values zip values: _*) + val it = subject.iterator + try { while (it.hasNext) it.next; true } catch { case _ => false } + } + property("sorted") = forAll { (subject: TreeMap[Int, String]) => (subject.size >= 3) ==> { subject.zip(subject.tail).forall { case (x, y) => x._1 < y._1 } }} diff --git a/test/files/scalacheck/treeset.scala b/test/files/scalacheck/treeset.scala index 3cefef7040..87c3eb7108 100644 --- a/test/files/scalacheck/treeset.scala +++ b/test/files/scalacheck/treeset.scala @@ -18,6 +18,22 @@ object Test extends Properties("TreeSet") { consistent } + property("worst-case tree height is iterable") = forAll(choose(0, 10), arbitrary[Boolean]) { (n: Int, even: Boolean) => + /* + * According to "Ralf Hinze. Constructing red-black trees" [http://www.cs.ox.ac.uk/ralf.hinze/publications/#P5] + * you can construct a skinny tree of height 2n by inserting the elements [1 .. 2^(n+1) - 2] and a tree of height + * 2n+1 by inserting the elements [1 .. 3 * 2^n - 2], both in reverse order. + * + * Since we allocate a fixed size buffer in the iterator (based on the tree size) we need to ensure + * it is big enough for these worst-case trees. + */ + val highest = if (even) (1 << (n+1)) - 2 else 3*(1 << n) - 2 + val values = (1 to highest).reverse + val subject = TreeSet(values: _*) + val it = subject.iterator + try { while (it.hasNext) it.next; true } catch { case _ => false } + } + property("sorted") = forAll { (subject: TreeSet[Int]) => (subject.size >= 3) ==> { subject.zip(subject.tail).forall { case (x, y) => x < y } }} -- cgit v1.2.3 From f656142ddbcecfd3f8482e2b55067de3d0ebd3ce Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Fri, 6 Jan 2012 23:19:39 +0100 Subject: Restore old RedBlack class to maintain backwards compatibility. The class is marked as deprecated and no longer used by the TreeMap/TreeSet implementation but is restored in case it was used by anyone else (since it was not marked as private to the Scala collection library). Renamed RedBlack.{Tree,RedTree,BlackTree} to Node, RedNode, and BlackNode to work around name clash with RedBlack class. --- .../scala/collection/immutable/RedBlack.scala | 561 ++++++++++++++++----- .../scala/collection/immutable/TreeMap.scala | 2 +- .../scala/collection/immutable/TreeSet.scala | 4 +- test/files/scalacheck/redblack.scala | 56 +- 4 files changed, 452 insertions(+), 171 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 30d3ff37a3..37ff7a7f54 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -26,167 +26,167 @@ import annotation.meta.getter private[immutable] object RedBlack { - private def blacken[A, B](t: Tree[A, B]): Tree[A, B] = if (t eq null) null else t.black + def isBlack(tree: Node[_, _]) = (tree eq null) || isBlackNode(tree) + def isRedNode(tree: Node[_, _]) = tree.isInstanceOf[RedNode[_, _]] + def isBlackNode(tree: Node[_, _]) = tree.isInstanceOf[BlackNode[_, _]] - private def mkTree[A, B](isBlack: Boolean, k: A, v: B, l: Tree[A, B], r: Tree[A, B]) = - if (isBlack) BlackTree(k, v, l, r) else RedTree(k, v, l, r) - - def isBlack(tree: Tree[_, _]) = (tree eq null) || isBlackTree(tree) - def isRedTree(tree: Tree[_, _]) = tree.isInstanceOf[RedTree[_, _]] - def isBlackTree(tree: Tree[_, _]) = tree.isInstanceOf[BlackTree[_, _]] + def isEmpty(tree: Node[_, _]): Boolean = tree eq null - def isEmpty(tree: Tree[_, _]): Boolean = tree eq null - - def contains[A](tree: Tree[A, _], x: A)(implicit ordering: Ordering[A]): Boolean = lookup(tree, x) ne null - def get[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Option[B] = lookup(tree, x) match { + def contains[A](tree: Node[A, _], x: A)(implicit ordering: Ordering[A]): Boolean = lookup(tree, x) ne null + def get[A, B](tree: Node[A, B], x: A)(implicit ordering: Ordering[A]): Option[B] = lookup(tree, x) match { case null => None case tree => Some(tree.value) } @tailrec - def lookup[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else { + def lookup[A, B](tree: Node[A, B], x: A)(implicit ordering: Ordering[A]): Node[A, B] = if (tree eq null) null else { val cmp = ordering.compare(x, tree.key) if (cmp < 0) lookup(tree.left, x) else if (cmp > 0) lookup(tree.right, x) else tree } - def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count - def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v)) - def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k)) - def range[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(rng(tree, from, until)) + def count(tree: Node[_, _]) = if (tree eq null) 0 else tree.count + def update[A, B, B1 >: B](tree: Node[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Node[A, B1] = blacken(upd(tree, k, v)) + def delete[A, B](tree: Node[A, B], k: A)(implicit ordering: Ordering[A]): Node[A, B] = blacken(del(tree, k)) + def range[A, B](tree: Node[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Node[A, B] = blacken(rng(tree, from, until)) - def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = { + def smallest[A, B](tree: Node[A, B]): Node[A, B] = { if (tree eq null) throw new NoSuchElementException("empty map") var result = tree while (result.left ne null) result = result.left result } - def greatest[A, B](tree: Tree[A, B]): Tree[A, B] = { + def greatest[A, B](tree: Node[A, B]): Node[A, B] = { if (tree eq null) throw new NoSuchElementException("empty map") var result = tree while (result.right ne null) result = result.right result } - def foreach[A, B, U](tree: Tree[A, B], f: ((A, B)) => U): Unit = if (tree ne null) { + def foreach[A, B, U](tree: Node[A, B], f: ((A, B)) => U): Unit = if (tree ne null) { if (tree.left ne null) foreach(tree.left, f) f((tree.key, tree.value)) if (tree.right ne null) foreach(tree.right, f) } - def foreachKey[A, U](tree: Tree[A, _], f: A => U): Unit = if (tree ne null) { + def foreachKey[A, U](tree: Node[A, _], f: A => U): Unit = if (tree ne null) { if (tree.left ne null) foreachKey(tree.left, f) f(tree.key) if (tree.right ne null) foreachKey(tree.right, f) } - def iterator[A, B](tree: Tree[A, B]): Iterator[(A, B)] = new EntriesIterator(tree) - def keysIterator[A, _](tree: Tree[A, _]): Iterator[A] = new KeysIterator(tree) - def valuesIterator[_, B](tree: Tree[_, B]): Iterator[B] = new ValuesIterator(tree) + def iterator[A, B](tree: Node[A, B]): Iterator[(A, B)] = new EntriesIterator(tree) + def keysIterator[A, _](tree: Node[A, _]): Iterator[A] = new KeysIterator(tree) + def valuesIterator[_, B](tree: Node[_, B]): Iterator[B] = new ValuesIterator(tree) @tailrec - def nth[A, B](tree: Tree[A, B], n: Int): Tree[A, B] = { + def nth[A, B](tree: Node[A, B], n: Int): Node[A, B] = { val count = RedBlack.count(tree.left) if (n < count) nth(tree.left, n) else if (n > count) nth(tree.right, n - count - 1) else tree } - private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[A, B1], d: Tree[A, B1]): Tree[A, B1] = { - if (isRedTree(l) && isRedTree(l.left)) - RedTree(l.key, l.value, BlackTree(l.left.key, l.left.value, l.left.left, l.left.right), BlackTree(z, zv, l.right, d)) - else if (isRedTree(l) && isRedTree(l.right)) - RedTree(l.right.key, l.right.value, BlackTree(l.key, l.value, l.left, l.right.left), BlackTree(z, zv, l.right.right, d)) + private def blacken[A, B](t: Node[A, B]): Node[A, B] = if (t eq null) null else t.black + + private def mkNode[A, B](isBlack: Boolean, k: A, v: B, l: Node[A, B], r: Node[A, B]) = + if (isBlack) BlackNode(k, v, l, r) else RedNode(k, v, l, r) + + private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Node[A, B1], d: Node[A, B1]): Node[A, B1] = { + if (isRedNode(l) && isRedNode(l.left)) + RedNode(l.key, l.value, BlackNode(l.left.key, l.left.value, l.left.left, l.left.right), BlackNode(z, zv, l.right, d)) + else if (isRedNode(l) && isRedNode(l.right)) + RedNode(l.right.key, l.right.value, BlackNode(l.key, l.value, l.left, l.right.left), BlackNode(z, zv, l.right.right, d)) else - mkTree(isBlack, z, zv, l, d) + mkNode(isBlack, z, zv, l, d) } - private[this] def balanceRight[A, B, B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[A, B1], r: Tree[A, B1]): Tree[A, B1] = { - if (isRedTree(r) && isRedTree(r.left)) - RedTree(r.left.key, r.left.value, BlackTree(x, xv, a, r.left.left), BlackTree(r.key, r.value, r.left.right, r.right)) - else if (isRedTree(r) && isRedTree(r.right)) - RedTree(r.key, r.value, BlackTree(x, xv, a, r.left), BlackTree(r.right.key, r.right.value, r.right.left, r.right.right)) + private[this] def balanceRight[A, B, B1 >: B](isBlack: Boolean, x: A, xv: B, a: Node[A, B1], r: Node[A, B1]): Node[A, B1] = { + if (isRedNode(r) && isRedNode(r.left)) + RedNode(r.left.key, r.left.value, BlackNode(x, xv, a, r.left.left), BlackNode(r.key, r.value, r.left.right, r.right)) + else if (isRedNode(r) && isRedNode(r.right)) + RedNode(r.key, r.value, BlackNode(x, xv, a, r.left), BlackNode(r.right.key, r.right.value, r.right.left, r.right.right)) else - mkTree(isBlack, x, xv, a, r) + mkNode(isBlack, x, xv, a, r) } - private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) { - RedTree(k, v, null, null) + private[this] def upd[A, B, B1 >: B](tree: Node[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Node[A, B1] = if (tree eq null) { + RedNode(k, v, null, null) } else { val cmp = ordering.compare(k, tree.key) if (cmp < 0) balanceLeft(tree.isBlack, tree.key, tree.value, upd(tree.left, k, v), tree.right) else if (cmp > 0) balanceRight(tree.isBlack, tree.key, tree.value, tree.left, upd(tree.right, k, v)) - else mkTree(tree.isBlack, k, v, tree.left, tree.right) + else mkNode(tree.isBlack, k, v, tree.left, tree.right) } // Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees - // http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html - private[this] def del[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else { - def balance(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { - if (isRedTree(tr)) { - RedTree(x, xv, tl.black, tr.black) - } else if (isRedTree(tl.left)) { - RedTree(tl.key, tl.value, tl.left.black, BlackTree(x, xv, tl.right, tr)) - } else if (isRedTree(tl.right)) { - RedTree(tl.right.key, tl.right.value, BlackTree(tl.key, tl.value, tl.left, tl.right.left), BlackTree(x, xv, tl.right.right, tr)) + // http://www.cse.unsw.edu.au/~dons/data/RedBlackNode.html + private[this] def del[A, B](tree: Node[A, B], k: A)(implicit ordering: Ordering[A]): Node[A, B] = if (tree eq null) null else { + def balance(x: A, xv: B, tl: Node[A, B], tr: Node[A, B]) = if (isRedNode(tl)) { + if (isRedNode(tr)) { + RedNode(x, xv, tl.black, tr.black) + } else if (isRedNode(tl.left)) { + RedNode(tl.key, tl.value, tl.left.black, BlackNode(x, xv, tl.right, tr)) + } else if (isRedNode(tl.right)) { + RedNode(tl.right.key, tl.right.value, BlackNode(tl.key, tl.value, tl.left, tl.right.left), BlackNode(x, xv, tl.right.right, tr)) } else { - BlackTree(x, xv, tl, tr) + BlackNode(x, xv, tl, tr) } - } else if (isRedTree(tr)) { - if (isRedTree(tr.right)) { - RedTree(tr.key, tr.value, BlackTree(x, xv, tl, tr.left), tr.right.black) - } else if (isRedTree(tr.left)) { - RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), BlackTree(tr.key, tr.value, tr.left.right, tr.right)) + } else if (isRedNode(tr)) { + if (isRedNode(tr.right)) { + RedNode(tr.key, tr.value, BlackNode(x, xv, tl, tr.left), tr.right.black) + } else if (isRedNode(tr.left)) { + RedNode(tr.left.key, tr.left.value, BlackNode(x, xv, tl, tr.left.left), BlackNode(tr.key, tr.value, tr.left.right, tr.right)) } else { - BlackTree(x, xv, tl, tr) + BlackNode(x, xv, tl, tr) } } else { - BlackTree(x, xv, tl, tr) + BlackNode(x, xv, tl, tr) } - def subl(t: Tree[A, B]) = - if (t.isInstanceOf[BlackTree[_, _]]) t.red + def subl(t: Node[A, B]) = + if (t.isInstanceOf[BlackNode[_, _]]) t.red else sys.error("Defect: invariance violation; expected black, got "+t) - def balLeft(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { - RedTree(x, xv, tl.black, tr) - } else if (isBlackTree(tr)) { + def balLeft(x: A, xv: B, tl: Node[A, B], tr: Node[A, B]) = if (isRedNode(tl)) { + RedNode(x, xv, tl.black, tr) + } else if (isBlackNode(tr)) { balance(x, xv, tl, tr.red) - } else if (isRedTree(tr) && isBlackTree(tr.left)) { - RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, subl(tr.right))) + } else if (isRedNode(tr) && isBlackNode(tr.left)) { + RedNode(tr.left.key, tr.left.value, BlackNode(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, subl(tr.right))) } else { - sys.error("Defect: invariance violation at ") // TODO + sys.error("Defect: invariance violation") } - def balRight(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tr)) { - RedTree(x, xv, tl, tr.black) - } else if (isBlackTree(tl)) { + def balRight(x: A, xv: B, tl: Node[A, B], tr: Node[A, B]) = if (isRedNode(tr)) { + RedNode(x, xv, tl, tr.black) + } else if (isBlackNode(tl)) { balance(x, xv, tl.red, tr) - } else if (isRedTree(tl) && isBlackTree(tl.right)) { - RedTree(tl.right.key, tl.right.value, balance(tl.key, tl.value, subl(tl.left), tl.right.left), BlackTree(x, xv, tl.right.right, tr)) + } else if (isRedNode(tl) && isBlackNode(tl.right)) { + RedNode(tl.right.key, tl.right.value, balance(tl.key, tl.value, subl(tl.left), tl.right.left), BlackNode(x, xv, tl.right.right, tr)) } else { - sys.error("Defect: invariance violation at ") // TODO + sys.error("Defect: invariance violation") } - def delLeft = if (isBlackTree(tree.left)) balLeft(tree.key, tree.value, del(tree.left, k), tree.right) else RedTree(tree.key, tree.value, del(tree.left, k), tree.right) - def delRight = if (isBlackTree(tree.right)) balRight(tree.key, tree.value, tree.left, del(tree.right, k)) else RedTree(tree.key, tree.value, tree.left, del(tree.right, k)) - def append(tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] = if (tl eq null) { + def delLeft = if (isBlackNode(tree.left)) balLeft(tree.key, tree.value, del(tree.left, k), tree.right) else RedNode(tree.key, tree.value, del(tree.left, k), tree.right) + def delRight = if (isBlackNode(tree.right)) balRight(tree.key, tree.value, tree.left, del(tree.right, k)) else RedNode(tree.key, tree.value, tree.left, del(tree.right, k)) + def append(tl: Node[A, B], tr: Node[A, B]): Node[A, B] = if (tl eq null) { tr } else if (tr eq null) { tl - } else if (isRedTree(tl) && isRedTree(tr)) { + } else if (isRedNode(tl) && isRedNode(tr)) { val bc = append(tl.right, tr.left) - if (isRedTree(bc)) { - RedTree(bc.key, bc.value, RedTree(tl.key, tl.value, tl.left, bc.left), RedTree(tr.key, tr.value, bc.right, tr.right)) + if (isRedNode(bc)) { + RedNode(bc.key, bc.value, RedNode(tl.key, tl.value, tl.left, bc.left), RedNode(tr.key, tr.value, bc.right, tr.right)) } else { - RedTree(tl.key, tl.value, tl.left, RedTree(tr.key, tr.value, bc, tr.right)) + RedNode(tl.key, tl.value, tl.left, RedNode(tr.key, tr.value, bc, tr.right)) } - } else if (isBlackTree(tl) && isBlackTree(tr)) { + } else if (isBlackNode(tl) && isBlackNode(tr)) { val bc = append(tl.right, tr.left) - if (isRedTree(bc)) { - RedTree(bc.key, bc.value, BlackTree(tl.key, tl.value, tl.left, bc.left), BlackTree(tr.key, tr.value, bc.right, tr.right)) + if (isRedNode(bc)) { + RedNode(bc.key, bc.value, BlackNode(tl.key, tl.value, tl.left, bc.left), BlackNode(tr.key, tr.value, bc.right, tr.right)) } else { - balLeft(tl.key, tl.value, tl.left, BlackTree(tr.key, tr.value, bc, tr.right)) + balLeft(tl.key, tl.value, tl.left, BlackNode(tr.key, tr.value, bc, tr.right)) } - } else if (isRedTree(tr)) { - RedTree(tr.key, tr.value, append(tl, tr.left), tr.right) - } else if (isRedTree(tl)) { - RedTree(tl.key, tl.value, tl.left, append(tl.right, tr)) + } else if (isRedNode(tr)) { + RedNode(tr.key, tr.value, append(tl, tr.left), tr.right) + } else if (isRedNode(tl)) { + RedNode(tl.key, tl.value, tl.left, append(tl.right, tr)) } else { sys.error("unmatched tree on append: " + tl + ", " + tr) } @@ -197,7 +197,7 @@ object RedBlack { else append(tree.left, tree.right) } - private[this] def rng[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = { + private[this] def rng[A, B](tree: Node[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Node[A, B] = { if (tree eq null) return null if (from == None && until == None) return tree if (from != None && ordering.lt(tree.key, from.get)) return rng(tree.right, from, until); @@ -219,9 +219,9 @@ object RedBlack { // whether the zipper was traversed left-most or right-most. // If the trees were balanced, returns an empty zipper - private[this] def compareDepth[A, B](left: Tree[A, B], right: Tree[A, B]): (List[Tree[A, B]], Boolean, Boolean, Int) = { + private[this] def compareDepth[A, B](left: Node[A, B], right: Node[A, B]): (List[Node[A, B]], Boolean, Boolean, Int) = { // Once a side is found to be deeper, unzip it to the bottom - def unzip(zipper: List[Tree[A, B]], leftMost: Boolean): List[Tree[A, B]] = { + def unzip(zipper: List[Node[A, B]], leftMost: Boolean): List[Node[A, B]] = { val next = if (leftMost) zipper.head.left else zipper.head.right next match { case null => zipper @@ -231,25 +231,25 @@ object RedBlack { // Unzip left tree on the rightmost side and right tree on the leftmost side until one is // found to be deeper, or the bottom is reached - def unzipBoth(left: Tree[A, B], - right: Tree[A, B], - leftZipper: List[Tree[A, B]], - rightZipper: List[Tree[A, B]], - smallerDepth: Int): (List[Tree[A, B]], Boolean, Boolean, Int) = { - if (isBlackTree(left) && isBlackTree(right)) { + def unzipBoth(left: Node[A, B], + right: Node[A, B], + leftZipper: List[Node[A, B]], + rightZipper: List[Node[A, B]], + smallerDepth: Int): (List[Node[A, B]], Boolean, Boolean, Int) = { + if (isBlackNode(left) && isBlackNode(right)) { unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth + 1) - } else if (isRedTree(left) && isRedTree(right)) { + } else if (isRedNode(left) && isRedNode(right)) { unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth) - } else if (isRedTree(right)) { + } else if (isRedNode(right)) { unzipBoth(left, right.left, leftZipper, right :: rightZipper, smallerDepth) - } else if (isRedTree(left)) { + } else if (isRedNode(left)) { unzipBoth(left.right, right, left :: leftZipper, rightZipper, smallerDepth) } else if ((left eq null) && (right eq null)) { (Nil, true, false, smallerDepth) - } else if ((left eq null) && isBlackTree(right)) { + } else if ((left eq null) && isBlackNode(right)) { val leftMost = true (unzip(right :: rightZipper, leftMost), false, leftMost, smallerDepth) - } else if (isBlackTree(left) && (right eq null)) { + } else if (isBlackNode(left) && (right eq null)) { val leftMost = false (unzip(left :: leftZipper, leftMost), false, leftMost, smallerDepth) } else { @@ -258,10 +258,10 @@ object RedBlack { } unzipBoth(left, right, Nil, Nil, 0) } - private[this] def rebalance[A, B](tree: Tree[A, B], newLeft: Tree[A, B], newRight: Tree[A, B]) = { + private[this] def rebalance[A, B](tree: Node[A, B], newLeft: Node[A, B], newRight: Node[A, B]) = { // This is like drop(n-1), but only counting black nodes - def findDepth(zipper: List[Tree[A, B]], depth: Int): List[Tree[A, B]] = zipper match { - case head :: tail if isBlackTree(head) => + def findDepth(zipper: List[Node[A, B]], depth: Int): List[Node[A, B]] = zipper match { + case head :: tail if isBlackNode(head) => if (depth == 1) zipper else findDepth(tail, depth - 1) case _ :: tail => findDepth(tail, depth) case Nil => sys.error("Defect: unexpected empty zipper while computing range") @@ -274,15 +274,15 @@ object RedBlack { val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) if (levelled) { - BlackTree(tree.key, tree.value, blkNewLeft, blkNewRight) + BlackNode(tree.key, tree.value, blkNewLeft, blkNewRight) } else { val zipFrom = findDepth(zipper, smallerDepth) val union = if (leftMost) { - RedTree(tree.key, tree.value, blkNewLeft, zipFrom.head) + RedNode(tree.key, tree.value, blkNewLeft, zipFrom.head) } else { - RedTree(tree.key, tree.value, zipFrom.head, blkNewRight) + RedNode(tree.key, tree.value, zipFrom.head, blkNewRight) } - val zippedTree = zipFrom.tail.foldLeft(union: Tree[A, B]) { (tree, node) => + val zippedTree = zipFrom.tail.foldLeft(union: Node[A, B]) { (tree, node) => if (leftMost) balanceLeft(node.isBlack, node.key, node.value, tree, node.right) else @@ -301,47 +301,47 @@ object RedBlack { * * An alternative is to implement the these classes using plain old Java code... */ - sealed abstract class Tree[A, +B]( + sealed abstract class Node[A, +B]( @(inline @getter) final val key: A, @(inline @getter) final val value: B, - @(inline @getter) final val left: Tree[A, B], - @(inline @getter) final val right: Tree[A, B]) + @(inline @getter) final val left: Node[A, B], + @(inline @getter) final val right: Node[A, B]) extends Serializable { final val count: Int = 1 + RedBlack.count(left) + RedBlack.count(right) def isBlack: Boolean - def black: Tree[A, B] - def red: Tree[A, B] + def black: Node[A, B] + def red: Node[A, B] } - final class RedTree[A, +B](key: A, + final class RedNode[A, +B](key: A, value: B, - left: Tree[A, B], - right: Tree[A, B]) extends Tree[A, B](key, value, left, right) { + left: Node[A, B], + right: Node[A, B]) extends Node[A, B](key, value, left, right) { override def isBlack = false - override def black = BlackTree(key, value, left, right) + override def black = BlackNode(key, value, left, right) override def red = this - override def toString = "RedTree(" + key + ", " + value + ", " + left + ", " + right + ")" + override def toString = "RedNode(" + key + ", " + value + ", " + left + ", " + right + ")" } - final class BlackTree[A, +B](key: A, + final class BlackNode[A, +B](key: A, value: B, - left: Tree[A, B], - right: Tree[A, B]) extends Tree[A, B](key, value, left, right) { + left: Node[A, B], + right: Node[A, B]) extends Node[A, B](key, value, left, right) { override def isBlack = true override def black = this - override def red = RedTree(key, value, left, right) - override def toString = "BlackTree(" + key + ", " + value + ", " + left + ", " + right + ")" + override def red = RedNode(key, value, left, right) + override def toString = "BlackNode(" + key + ", " + value + ", " + left + ", " + right + ")" } - object RedTree { - @inline def apply[A, B](key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = new RedTree(key, value, left, right) - def unapply[A, B](t: RedTree[A, B]) = Some((t.key, t.value, t.left, t.right)) + object RedNode { + @inline def apply[A, B](key: A, value: B, left: Node[A, B], right: Node[A, B]) = new RedNode(key, value, left, right) + def unapply[A, B](t: RedNode[A, B]) = Some((t.key, t.value, t.left, t.right)) } - object BlackTree { - @inline def apply[A, B](key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = new BlackTree(key, value, left, right) - def unapply[A, B](t: BlackTree[A, B]) = Some((t.key, t.value, t.left, t.right)) + object BlackNode { + @inline def apply[A, B](key: A, value: B, left: Node[A, B], right: Node[A, B]) = new BlackNode(key, value, left, right) + def unapply[A, B](t: BlackNode[A, B]) = Some((t.key, t.value, t.left, t.right)) } - private[this] abstract class TreeIterator[A, B, R](tree: Tree[A, B]) extends Iterator[R] { - protected[this] def nextResult(tree: Tree[A, B]): R + private[this] abstract class TreeIterator[A, B, R](tree: Node[A, B]) extends Iterator[R] { + protected[this] def nextResult(tree: Node[A, B]): R override def hasNext: Boolean = next ne null @@ -354,7 +354,7 @@ object RedBlack { } @tailrec - private[this] def findNext(tree: Tree[A, B]): Tree[A, B] = { + private[this] def findNext(tree: Node[A, B]): Node[A, B] = { if (tree eq null) popPath() else if (tree.left eq null) tree else { @@ -363,7 +363,7 @@ object RedBlack { } } - private[this] def pushPath(tree: Tree[A, B]) { + private[this] def pushPath(tree: Node[A, B]) { try { path(index) = tree index += 1 @@ -382,7 +382,7 @@ object RedBlack { pushPath(tree) } } - private[this] def popPath(): Tree[A, B] = if (index == 0) null else { + private[this] def popPath(): Node[A, B] = if (index == 0) null else { index -= 1 path(index) } @@ -397,21 +397,302 @@ object RedBlack { * We also don't store the deepest nodes in the path so the maximum path length is further reduced by one. */ val maximumHeight = 2 * (32 - Integer.numberOfLeadingZeros(tree.count + 2 - 1)) - 2 - 1 - new Array[Tree[A, B]](maximumHeight) + new Array[Node[A, B]](maximumHeight) } private[this] var index = 0 - private[this] var next: Tree[A, B] = findNext(tree) + private[this] var next: Node[A, B] = findNext(tree) + } + + private[this] class EntriesIterator[A, B](tree: Node[A, B]) extends TreeIterator[A, B, (A, B)](tree) { + override def nextResult(tree: Node[A, B]) = (tree.key, tree.value) } - private[this] class EntriesIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, (A, B)](tree) { - override def nextResult(tree: Tree[A, B]) = (tree.key, tree.value) + private[this] class KeysIterator[A, B](tree: Node[A, B]) extends TreeIterator[A, B, A](tree) { + override def nextResult(tree: Node[A, B]) = tree.key } - private[this] class KeysIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, A](tree) { - override def nextResult(tree: Tree[A, B]) = tree.key + private[this] class ValuesIterator[A, B](tree: Node[A, B]) extends TreeIterator[A, B, B](tree) { + override def nextResult(tree: Node[A, B]) = tree.value } +} + + +/** Old base class that was used by previous implementations of `TreeMaps` and `TreeSets`. + * + * Deprecated due to various performance bugs (see [[https://issues.scala-lang.org/browse/SI-5331 SI-5331]] for more information). + * + * @since 2.3 + */ +@deprecated("use `TreeMap` or `TreeSet` instead", "2.10") +@SerialVersionUID(8691885935445612921L) +abstract class RedBlack[A] extends Serializable { - private[this] class ValuesIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, B](tree) { - override def nextResult(tree: Tree[A, B]) = tree.value + def isSmaller(x: A, y: A): Boolean + + private def blacken[B](t: Tree[B]): Tree[B] = t match { + case RedTree(k, v, l, r) => BlackTree(k, v, l, r) + case t => t + } + private def mkTree[B](isBlack: Boolean, k: A, v: B, l: Tree[B], r: Tree[B]) = + if (isBlack) BlackTree(k, v, l, r) else RedTree(k, v, l, r) + + abstract class Tree[+B] extends Serializable { + def isEmpty: Boolean + def isBlack: Boolean + def lookup(x: A): Tree[B] + def update[B1 >: B](k: A, v: B1): Tree[B1] = blacken(upd(k, v)) + def delete(k: A): Tree[B] = blacken(del(k)) + def range(from: Option[A], until: Option[A]): Tree[B] = blacken(rng(from, until)) + def foreach[U](f: (A, B) => U) + def toStream: Stream[(A,B)] + def iterator: Iterator[(A, B)] + def upd[B1 >: B](k: A, v: B1): Tree[B1] + def del(k: A): Tree[B] + def smallest: NonEmpty[B] + def rng(from: Option[A], until: Option[A]): Tree[B] + def first : A + def last : A + def count : Int + } + abstract class NonEmpty[+B] extends Tree[B] with Serializable { + def isEmpty = false + def key: A + def value: B + def left: Tree[B] + def right: Tree[B] + def lookup(k: A): Tree[B] = + if (isSmaller(k, key)) left.lookup(k) + else if (isSmaller(key, k)) right.lookup(k) + else this + private[this] def balanceLeft[B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[B1], d: Tree[B1])/*: NonEmpty[B1]*/ = l match { + case RedTree(y, yv, RedTree(x, xv, a, b), c) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case RedTree(x, xv, a, RedTree(y, yv, b, c)) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case _ => + mkTree(isBlack, z, zv, l, d) + } + private[this] def balanceRight[B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[B1], r: Tree[B1])/*: NonEmpty[B1]*/ = r match { + case RedTree(z, zv, RedTree(y, yv, b, c), d) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case RedTree(y, yv, b, RedTree(z, zv, c, d)) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case _ => + mkTree(isBlack, x, xv, a, r) + } + def upd[B1 >: B](k: A, v: B1): Tree[B1] = { + if (isSmaller(k, key)) balanceLeft(isBlack, key, value, left.upd(k, v), right) + else if (isSmaller(key, k)) balanceRight(isBlack, key, value, left, right.upd(k, v)) + else mkTree(isBlack, k, v, left, right) + } + // Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees + // http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html + def del(k: A): Tree[B] = { + def balance(x: A, xv: B, tl: Tree[B], tr: Tree[B]) = (tl, tr) match { + case (RedTree(y, yv, a, b), RedTree(z, zv, c, d)) => + RedTree(x, xv, BlackTree(y, yv, a, b), BlackTree(z, zv, c, d)) + case (RedTree(y, yv, RedTree(z, zv, a, b), c), d) => + RedTree(y, yv, BlackTree(z, zv, a, b), BlackTree(x, xv, c, d)) + case (RedTree(y, yv, a, RedTree(z, zv, b, c)), d) => + RedTree(z, zv, BlackTree(y, yv, a, b), BlackTree(x, xv, c, d)) + case (a, RedTree(y, yv, b, RedTree(z, zv, c, d))) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case (a, RedTree(y, yv, RedTree(z, zv, b, c), d)) => + RedTree(z, zv, BlackTree(x, xv, a, b), BlackTree(y, yv, c, d)) + case (a, b) => + BlackTree(x, xv, a, b) + } + def subl(t: Tree[B]) = t match { + case BlackTree(x, xv, a, b) => RedTree(x, xv, a, b) + case _ => sys.error("Defect: invariance violation; expected black, got "+t) + } + def balLeft(x: A, xv: B, tl: Tree[B], tr: Tree[B]) = (tl, tr) match { + case (RedTree(y, yv, a, b), c) => + RedTree(x, xv, BlackTree(y, yv, a, b), c) + case (bl, BlackTree(y, yv, a, b)) => + balance(x, xv, bl, RedTree(y, yv, a, b)) + case (bl, RedTree(y, yv, BlackTree(z, zv, a, b), c)) => + RedTree(z, zv, BlackTree(x, xv, bl, a), balance(y, yv, b, subl(c))) + case _ => sys.error("Defect: invariance violation at "+right) + } + def balRight(x: A, xv: B, tl: Tree[B], tr: Tree[B]) = (tl, tr) match { + case (a, RedTree(y, yv, b, c)) => + RedTree(x, xv, a, BlackTree(y, yv, b, c)) + case (BlackTree(y, yv, a, b), bl) => + balance(x, xv, RedTree(y, yv, a, b), bl) + case (RedTree(y, yv, a, BlackTree(z, zv, b, c)), bl) => + RedTree(z, zv, balance(y, yv, subl(a), b), BlackTree(x, xv, c, bl)) + case _ => sys.error("Defect: invariance violation at "+left) + } + def delLeft = left match { + case _: BlackTree[_] => balLeft(key, value, left.del(k), right) + case _ => RedTree(key, value, left.del(k), right) + } + def delRight = right match { + case _: BlackTree[_] => balRight(key, value, left, right.del(k)) + case _ => RedTree(key, value, left, right.del(k)) + } + def append(tl: Tree[B], tr: Tree[B]): Tree[B] = (tl, tr) match { + case (Empty, t) => t + case (t, Empty) => t + case (RedTree(x, xv, a, b), RedTree(y, yv, c, d)) => + append(b, c) match { + case RedTree(z, zv, bb, cc) => RedTree(z, zv, RedTree(x, xv, a, bb), RedTree(y, yv, cc, d)) + case bc => RedTree(x, xv, a, RedTree(y, yv, bc, d)) + } + case (BlackTree(x, xv, a, b), BlackTree(y, yv, c, d)) => + append(b, c) match { + case RedTree(z, zv, bb, cc) => RedTree(z, zv, BlackTree(x, xv, a, bb), BlackTree(y, yv, cc, d)) + case bc => balLeft(x, xv, a, BlackTree(y, yv, bc, d)) + } + case (a, RedTree(x, xv, b, c)) => RedTree(x, xv, append(a, b), c) + case (RedTree(x, xv, a, b), c) => RedTree(x, xv, a, append(b, c)) + } + // RedBlack is neither A : Ordering[A], nor A <% Ordered[A] + k match { + case _ if isSmaller(k, key) => delLeft + case _ if isSmaller(key, k) => delRight + case _ => append(left, right) + } + } + + def smallest: NonEmpty[B] = if (left.isEmpty) this else left.smallest + + def toStream: Stream[(A,B)] = + left.toStream ++ Stream((key,value)) ++ right.toStream + + def iterator: Iterator[(A, B)] = + left.iterator ++ Iterator.single(Pair(key, value)) ++ right.iterator + + def foreach[U](f: (A, B) => U) { + left foreach f + f(key, value) + right foreach f + } + + override def rng(from: Option[A], until: Option[A]): Tree[B] = { + if (from == None && until == None) return this + if (from != None && isSmaller(key, from.get)) return right.rng(from, until); + if (until != None && (isSmaller(until.get,key) || !isSmaller(key,until.get))) + return left.rng(from, until); + val newLeft = left.rng(from, None) + val newRight = right.rng(None, until) + if ((newLeft eq left) && (newRight eq right)) this + else if (newLeft eq Empty) newRight.upd(key, value); + else if (newRight eq Empty) newLeft.upd(key, value); + else rebalance(newLeft, newRight) + } + + // The zipper returned might have been traversed left-most (always the left child) + // or right-most (always the right child). Left trees are traversed right-most, + // and right trees are traversed leftmost. + + // Returns the zipper for the side with deepest black nodes depth, a flag + // indicating whether the trees were unbalanced at all, and a flag indicating + // whether the zipper was traversed left-most or right-most. + + // If the trees were balanced, returns an empty zipper + private[this] def compareDepth(left: Tree[B], right: Tree[B]): (List[NonEmpty[B]], Boolean, Boolean, Int) = { + // Once a side is found to be deeper, unzip it to the bottom + def unzip(zipper: List[NonEmpty[B]], leftMost: Boolean): List[NonEmpty[B]] = { + val next = if (leftMost) zipper.head.left else zipper.head.right + next match { + case node: NonEmpty[_] => unzip(node :: zipper, leftMost) + case Empty => zipper + } + } + + // Unzip left tree on the rightmost side and right tree on the leftmost side until one is + // found to be deeper, or the bottom is reached + def unzipBoth(left: Tree[B], + right: Tree[B], + leftZipper: List[NonEmpty[B]], + rightZipper: List[NonEmpty[B]], + smallerDepth: Int): (List[NonEmpty[B]], Boolean, Boolean, Int) = (left, right) match { + case (l @ BlackTree(_, _, _, _), r @ BlackTree(_, _, _, _)) => + unzipBoth(l.right, r.left, l :: leftZipper, r :: rightZipper, smallerDepth + 1) + case (l @ RedTree(_, _, _, _), r @ RedTree(_, _, _, _)) => + unzipBoth(l.right, r.left, l :: leftZipper, r :: rightZipper, smallerDepth) + case (_, r @ RedTree(_, _, _, _)) => + unzipBoth(left, r.left, leftZipper, r :: rightZipper, smallerDepth) + case (l @ RedTree(_, _, _, _), _) => + unzipBoth(l.right, right, l :: leftZipper, rightZipper, smallerDepth) + case (Empty, Empty) => + (Nil, true, false, smallerDepth) + case (Empty, r @ BlackTree(_, _, _, _)) => + val leftMost = true + (unzip(r :: rightZipper, leftMost), false, leftMost, smallerDepth) + case (l @ BlackTree(_, _, _, _), Empty) => + val leftMost = false + (unzip(l :: leftZipper, leftMost), false, leftMost, smallerDepth) + } + unzipBoth(left, right, Nil, Nil, 0) + } + + private[this] def rebalance(newLeft: Tree[B], newRight: Tree[B]) = { + // This is like drop(n-1), but only counting black nodes + def findDepth(zipper: List[NonEmpty[B]], depth: Int): List[NonEmpty[B]] = zipper match { + case BlackTree(_, _, _, _) :: tail => + if (depth == 1) zipper else findDepth(tail, depth - 1) + case _ :: tail => findDepth(tail, depth) + case Nil => sys.error("Defect: unexpected empty zipper while computing range") + } + + // Blackening the smaller tree avoids balancing problems on union; + // this can't be done later, though, or it would change the result of compareDepth + val blkNewLeft = blacken(newLeft) + val blkNewRight = blacken(newRight) + val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) + + if (levelled) { + BlackTree(key, value, blkNewLeft, blkNewRight) + } else { + val zipFrom = findDepth(zipper, smallerDepth) + val union = if (leftMost) { + RedTree(key, value, blkNewLeft, zipFrom.head) + } else { + RedTree(key, value, zipFrom.head, blkNewRight) + } + val zippedTree = zipFrom.tail.foldLeft(union: Tree[B]) { (tree, node) => + if (leftMost) + balanceLeft(node.isBlack, node.key, node.value, tree, node.right) + else + balanceRight(node.isBlack, node.key, node.value, node.left, tree) + } + zippedTree + } + } + def first = if (left .isEmpty) key else left.first + def last = if (right.isEmpty) key else right.last + def count = 1 + left.count + right.count + } + case object Empty extends Tree[Nothing] { + def isEmpty = true + def isBlack = true + def lookup(k: A): Tree[Nothing] = this + def upd[B](k: A, v: B): Tree[B] = RedTree(k, v, Empty, Empty) + def del(k: A): Tree[Nothing] = this + def smallest: NonEmpty[Nothing] = throw new NoSuchElementException("empty map") + def iterator: Iterator[(A, Nothing)] = Iterator.empty + def toStream: Stream[(A,Nothing)] = Stream.empty + + def foreach[U](f: (A, Nothing) => U) {} + + def rng(from: Option[A], until: Option[A]) = this + def first = throw new NoSuchElementException("empty map") + def last = throw new NoSuchElementException("empty map") + def count = 0 + } + case class RedTree[+B](override val key: A, + override val value: B, + override val left: Tree[B], + override val right: Tree[B]) extends NonEmpty[B] { + def isBlack = false + } + case class BlackTree[+B](override val key: A, + override val value: B, + override val left: Tree[B], + override val right: Tree[B]) extends NonEmpty[B] { + def isBlack = true } } diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 65e42ad061..50244ef21d 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -45,7 +45,7 @@ object TreeMap extends ImmutableSortedMapFactory[TreeMap] { * @define mayNotTerminateInf * @define willNotTerminateInf */ -class TreeMap[A, +B] private (tree: RedBlack.Tree[A, B])(implicit val ordering: Ordering[A]) +class TreeMap[A, +B] private (tree: RedBlack.Node[A, B])(implicit val ordering: Ordering[A]) extends SortedMap[A, B] with SortedMapLike[A, B, TreeMap[A, B]] with MapLike[A, B, TreeMap[A, B]] diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index f7ceafdf8f..899ef0e5eb 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -47,7 +47,7 @@ object TreeSet extends ImmutableSortedSetFactory[TreeSet] { * @define willNotTerminateInf */ @SerialVersionUID(-5685982407650748405L) -class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: Ordering[A]) +class TreeSet[A] private (tree: RedBlack.Node[A, Unit])(implicit val ordering: Ordering[A]) extends SortedSet[A] with SortedSetLike[A, TreeSet[A]] with Serializable { import immutable.{RedBlack => RB} @@ -105,7 +105,7 @@ class TreeSet[A] private (tree: RedBlack.Tree[A, Unit])(implicit val ordering: O def this()(implicit ordering: Ordering[A]) = this(null)(ordering) - private def newSet(t: RedBlack.Tree[A, Unit]) = new TreeSet[A](t) + private def newSet(t: RedBlack.Node[A, Unit]) = new TreeSet[A](t) /** A factory to create empty sets of the same type of keys. */ diff --git a/test/files/scalacheck/redblack.scala b/test/files/scalacheck/redblack.scala index 5c52a27e38..83d3ca0c1f 100644 --- a/test/files/scalacheck/redblack.scala +++ b/test/files/scalacheck/redblack.scala @@ -22,14 +22,14 @@ abstract class RedBlackTest extends Properties("RedBlack") { import RedBlack._ - def nodeAt[A](tree: Tree[String, A], n: Int): Option[(String, A)] = if (n < iterator(tree).size && n >= 0) + def nodeAt[A](tree: Node[String, A], n: Int): Option[(String, A)] = if (n < iterator(tree).size && n >= 0) Some(iterator(tree).drop(n).next) else None - def treeContains[A](tree: Tree[String, A], key: String) = iterator(tree).map(_._1) contains key + def treeContains[A](tree: Node[String, A], key: String) = iterator(tree).map(_._1) contains key - def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[String, Int]] = + def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Node[String, Int]] = if (level == 0) { value(null) } else { @@ -42,9 +42,9 @@ abstract class RedBlackTest extends Properties("RedBlack") { right <- mkTree(nextLevel, !isRed, label + "R") } yield { if (isRed) - RedTree(label + "N", 0, left, right) + RedNode(label + "N", 0, left, right) else - BlackTree(label + "N", 0, left, right) + BlackNode(label + "N", 0, left, right) } } @@ -54,10 +54,10 @@ abstract class RedBlackTest extends Properties("RedBlack") { } yield tree type ModifyParm - def genParm(tree: Tree[String, Int]): Gen[ModifyParm] - def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] + def genParm(tree: Node[String, Int]): Gen[ModifyParm] + def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] - def genInput: Gen[(Tree[String, Int], ModifyParm, Tree[String, Int])] = for { + def genInput: Gen[(Node[String, Int], ModifyParm, Node[String, Int])] = for { tree <- genTree parm <- genParm(tree) } yield (tree, parm, modify(tree, parm)) @@ -68,26 +68,26 @@ trait RedBlackInvariants { import RedBlack._ - def rootIsBlack[A](t: Tree[String, A]) = isBlack(t) + def rootIsBlack[A](t: Node[String, A]) = isBlack(t) - def areAllLeavesBlack[A](t: Tree[String, A]): Boolean = t match { + def areAllLeavesBlack[A](t: Node[String, A]): Boolean = t match { case null => isBlack(t) case ne => List(ne.left, ne.right) forall areAllLeavesBlack } - def areRedNodeChildrenBlack[A](t: Tree[String, A]): Boolean = t match { - case RedTree(_, _, left, right) => List(left, right) forall (t => isBlack(t) && areRedNodeChildrenBlack(t)) - case BlackTree(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack + def areRedNodeChildrenBlack[A](t: Node[String, A]): Boolean = t match { + case RedNode(_, _, left, right) => List(left, right) forall (t => isBlack(t) && areRedNodeChildrenBlack(t)) + case BlackNode(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack case null => true } - def blackNodesToLeaves[A](t: Tree[String, A]): List[Int] = t match { + def blackNodesToLeaves[A](t: Node[String, A]): List[Int] = t match { case null => List(1) - case BlackTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) - case RedTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves + case BlackNode(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) + case RedNode(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves } - def areBlackNodesToLeavesEqual[A](t: Tree[String, A]): Boolean = t match { + def areBlackNodesToLeavesEqual[A](t: Node[String, A]): Boolean = t match { case null => true case ne => ( @@ -97,10 +97,10 @@ trait RedBlackInvariants { ) } - def orderIsPreserved[A](t: Tree[String, A]): Boolean = + def orderIsPreserved[A](t: Node[String, A]): Boolean = iterator(t) zip iterator(t).drop(1) forall { case (x, y) => x._1 < y._1 } - def setup(invariant: Tree[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => + def setup(invariant: Node[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => invariant(newTree) } @@ -115,10 +115,10 @@ object TestInsert extends RedBlackTest with RedBlackInvariants { import RedBlack._ override type ModifyParm = Int - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0) + override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1) + override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = update(tree, generateKey(tree, parm), 0) - def generateKey(tree: Tree[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match { + def generateKey(tree: Node[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match { case Some((key, _)) => key.init.mkString + "MN" case None => nodeAt(tree, parm - 1) match { case Some((key, _)) => key.init.mkString + "RN" @@ -137,8 +137,8 @@ object TestModify extends RedBlackTest { def newValue = 1 override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { + override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) + override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = nodeAt(tree, parm) map { case (key, _) => update(tree, key, newValue) } getOrElse tree @@ -154,8 +154,8 @@ object TestDelete extends RedBlackTest with RedBlackInvariants { override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { + override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) + override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = nodeAt(tree, parm) map { case (key, _) => delete(tree, key) } getOrElse tree @@ -170,14 +170,14 @@ object TestRange extends RedBlackTest with RedBlackInvariants { import RedBlack._ override type ModifyParm = (Option[Int], Option[Int]) - override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { + override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = for { from <- choose(0, iterator(tree).size) to <- choose(0, iterator(tree).size) suchThat (from <=) optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug } yield (optionalFrom, optionalTo) - override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = { + override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = { val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) range(tree, from, to) -- cgit v1.2.3 From 288874d80856317744c582f1468d7af420d9e0ee Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sat, 7 Jan 2012 15:26:40 +0100 Subject: Renamed object RedBlack to RedBlackTree. This more clearly separates the new implementation from the now deprecated abstract class RedBlack and avoids naming conflicts for the member classes. --- .../scala/collection/immutable/RedBlack.scala | 406 -------------------- .../scala/collection/immutable/RedBlackTree.scala | 416 +++++++++++++++++++++ .../scala/collection/immutable/TreeMap.scala | 5 +- .../scala/collection/immutable/TreeSet.scala | 7 +- test/files/scalacheck/redblack.scala | 113 +++--- test/files/scalacheck/redblacktree.scala | 212 +++++++++++ 6 files changed, 690 insertions(+), 469 deletions(-) create mode 100644 src/library/scala/collection/immutable/RedBlackTree.scala create mode 100644 test/files/scalacheck/redblacktree.scala (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 37ff7a7f54..83eeaa45ee 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -11,412 +11,6 @@ package scala.collection package immutable -import annotation.tailrec -import annotation.meta.getter - -/** An object containing the RedBlack tree implementation used by for `TreeMaps` and `TreeSets`. - * - * Implementation note: since efficiency is important for data structures this implementation - * uses null to represent empty trees. This also means pattern matching cannot - * easily be used. The API represented by the RedBlack object tries to hide these optimizations - * behind a reasonably clean API. - * - * @since 2.3 - */ -private[immutable] -object RedBlack { - - def isBlack(tree: Node[_, _]) = (tree eq null) || isBlackNode(tree) - def isRedNode(tree: Node[_, _]) = tree.isInstanceOf[RedNode[_, _]] - def isBlackNode(tree: Node[_, _]) = tree.isInstanceOf[BlackNode[_, _]] - - def isEmpty(tree: Node[_, _]): Boolean = tree eq null - - def contains[A](tree: Node[A, _], x: A)(implicit ordering: Ordering[A]): Boolean = lookup(tree, x) ne null - def get[A, B](tree: Node[A, B], x: A)(implicit ordering: Ordering[A]): Option[B] = lookup(tree, x) match { - case null => None - case tree => Some(tree.value) - } - - @tailrec - def lookup[A, B](tree: Node[A, B], x: A)(implicit ordering: Ordering[A]): Node[A, B] = if (tree eq null) null else { - val cmp = ordering.compare(x, tree.key) - if (cmp < 0) lookup(tree.left, x) - else if (cmp > 0) lookup(tree.right, x) - else tree - } - - def count(tree: Node[_, _]) = if (tree eq null) 0 else tree.count - def update[A, B, B1 >: B](tree: Node[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Node[A, B1] = blacken(upd(tree, k, v)) - def delete[A, B](tree: Node[A, B], k: A)(implicit ordering: Ordering[A]): Node[A, B] = blacken(del(tree, k)) - def range[A, B](tree: Node[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Node[A, B] = blacken(rng(tree, from, until)) - - def smallest[A, B](tree: Node[A, B]): Node[A, B] = { - if (tree eq null) throw new NoSuchElementException("empty map") - var result = tree - while (result.left ne null) result = result.left - result - } - def greatest[A, B](tree: Node[A, B]): Node[A, B] = { - if (tree eq null) throw new NoSuchElementException("empty map") - var result = tree - while (result.right ne null) result = result.right - result - } - - def foreach[A, B, U](tree: Node[A, B], f: ((A, B)) => U): Unit = if (tree ne null) { - if (tree.left ne null) foreach(tree.left, f) - f((tree.key, tree.value)) - if (tree.right ne null) foreach(tree.right, f) - } - def foreachKey[A, U](tree: Node[A, _], f: A => U): Unit = if (tree ne null) { - if (tree.left ne null) foreachKey(tree.left, f) - f(tree.key) - if (tree.right ne null) foreachKey(tree.right, f) - } - - def iterator[A, B](tree: Node[A, B]): Iterator[(A, B)] = new EntriesIterator(tree) - def keysIterator[A, _](tree: Node[A, _]): Iterator[A] = new KeysIterator(tree) - def valuesIterator[_, B](tree: Node[_, B]): Iterator[B] = new ValuesIterator(tree) - - @tailrec - def nth[A, B](tree: Node[A, B], n: Int): Node[A, B] = { - val count = RedBlack.count(tree.left) - if (n < count) nth(tree.left, n) - else if (n > count) nth(tree.right, n - count - 1) - else tree - } - - private def blacken[A, B](t: Node[A, B]): Node[A, B] = if (t eq null) null else t.black - - private def mkNode[A, B](isBlack: Boolean, k: A, v: B, l: Node[A, B], r: Node[A, B]) = - if (isBlack) BlackNode(k, v, l, r) else RedNode(k, v, l, r) - - private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Node[A, B1], d: Node[A, B1]): Node[A, B1] = { - if (isRedNode(l) && isRedNode(l.left)) - RedNode(l.key, l.value, BlackNode(l.left.key, l.left.value, l.left.left, l.left.right), BlackNode(z, zv, l.right, d)) - else if (isRedNode(l) && isRedNode(l.right)) - RedNode(l.right.key, l.right.value, BlackNode(l.key, l.value, l.left, l.right.left), BlackNode(z, zv, l.right.right, d)) - else - mkNode(isBlack, z, zv, l, d) - } - private[this] def balanceRight[A, B, B1 >: B](isBlack: Boolean, x: A, xv: B, a: Node[A, B1], r: Node[A, B1]): Node[A, B1] = { - if (isRedNode(r) && isRedNode(r.left)) - RedNode(r.left.key, r.left.value, BlackNode(x, xv, a, r.left.left), BlackNode(r.key, r.value, r.left.right, r.right)) - else if (isRedNode(r) && isRedNode(r.right)) - RedNode(r.key, r.value, BlackNode(x, xv, a, r.left), BlackNode(r.right.key, r.right.value, r.right.left, r.right.right)) - else - mkNode(isBlack, x, xv, a, r) - } - private[this] def upd[A, B, B1 >: B](tree: Node[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Node[A, B1] = if (tree eq null) { - RedNode(k, v, null, null) - } else { - val cmp = ordering.compare(k, tree.key) - if (cmp < 0) balanceLeft(tree.isBlack, tree.key, tree.value, upd(tree.left, k, v), tree.right) - else if (cmp > 0) balanceRight(tree.isBlack, tree.key, tree.value, tree.left, upd(tree.right, k, v)) - else mkNode(tree.isBlack, k, v, tree.left, tree.right) - } - - // Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees - // http://www.cse.unsw.edu.au/~dons/data/RedBlackNode.html - private[this] def del[A, B](tree: Node[A, B], k: A)(implicit ordering: Ordering[A]): Node[A, B] = if (tree eq null) null else { - def balance(x: A, xv: B, tl: Node[A, B], tr: Node[A, B]) = if (isRedNode(tl)) { - if (isRedNode(tr)) { - RedNode(x, xv, tl.black, tr.black) - } else if (isRedNode(tl.left)) { - RedNode(tl.key, tl.value, tl.left.black, BlackNode(x, xv, tl.right, tr)) - } else if (isRedNode(tl.right)) { - RedNode(tl.right.key, tl.right.value, BlackNode(tl.key, tl.value, tl.left, tl.right.left), BlackNode(x, xv, tl.right.right, tr)) - } else { - BlackNode(x, xv, tl, tr) - } - } else if (isRedNode(tr)) { - if (isRedNode(tr.right)) { - RedNode(tr.key, tr.value, BlackNode(x, xv, tl, tr.left), tr.right.black) - } else if (isRedNode(tr.left)) { - RedNode(tr.left.key, tr.left.value, BlackNode(x, xv, tl, tr.left.left), BlackNode(tr.key, tr.value, tr.left.right, tr.right)) - } else { - BlackNode(x, xv, tl, tr) - } - } else { - BlackNode(x, xv, tl, tr) - } - def subl(t: Node[A, B]) = - if (t.isInstanceOf[BlackNode[_, _]]) t.red - else sys.error("Defect: invariance violation; expected black, got "+t) - - def balLeft(x: A, xv: B, tl: Node[A, B], tr: Node[A, B]) = if (isRedNode(tl)) { - RedNode(x, xv, tl.black, tr) - } else if (isBlackNode(tr)) { - balance(x, xv, tl, tr.red) - } else if (isRedNode(tr) && isBlackNode(tr.left)) { - RedNode(tr.left.key, tr.left.value, BlackNode(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, subl(tr.right))) - } else { - sys.error("Defect: invariance violation") - } - def balRight(x: A, xv: B, tl: Node[A, B], tr: Node[A, B]) = if (isRedNode(tr)) { - RedNode(x, xv, tl, tr.black) - } else if (isBlackNode(tl)) { - balance(x, xv, tl.red, tr) - } else if (isRedNode(tl) && isBlackNode(tl.right)) { - RedNode(tl.right.key, tl.right.value, balance(tl.key, tl.value, subl(tl.left), tl.right.left), BlackNode(x, xv, tl.right.right, tr)) - } else { - sys.error("Defect: invariance violation") - } - def delLeft = if (isBlackNode(tree.left)) balLeft(tree.key, tree.value, del(tree.left, k), tree.right) else RedNode(tree.key, tree.value, del(tree.left, k), tree.right) - def delRight = if (isBlackNode(tree.right)) balRight(tree.key, tree.value, tree.left, del(tree.right, k)) else RedNode(tree.key, tree.value, tree.left, del(tree.right, k)) - def append(tl: Node[A, B], tr: Node[A, B]): Node[A, B] = if (tl eq null) { - tr - } else if (tr eq null) { - tl - } else if (isRedNode(tl) && isRedNode(tr)) { - val bc = append(tl.right, tr.left) - if (isRedNode(bc)) { - RedNode(bc.key, bc.value, RedNode(tl.key, tl.value, tl.left, bc.left), RedNode(tr.key, tr.value, bc.right, tr.right)) - } else { - RedNode(tl.key, tl.value, tl.left, RedNode(tr.key, tr.value, bc, tr.right)) - } - } else if (isBlackNode(tl) && isBlackNode(tr)) { - val bc = append(tl.right, tr.left) - if (isRedNode(bc)) { - RedNode(bc.key, bc.value, BlackNode(tl.key, tl.value, tl.left, bc.left), BlackNode(tr.key, tr.value, bc.right, tr.right)) - } else { - balLeft(tl.key, tl.value, tl.left, BlackNode(tr.key, tr.value, bc, tr.right)) - } - } else if (isRedNode(tr)) { - RedNode(tr.key, tr.value, append(tl, tr.left), tr.right) - } else if (isRedNode(tl)) { - RedNode(tl.key, tl.value, tl.left, append(tl.right, tr)) - } else { - sys.error("unmatched tree on append: " + tl + ", " + tr) - } - - val cmp = ordering.compare(k, tree.key) - if (cmp < 0) delLeft - else if (cmp > 0) delRight - else append(tree.left, tree.right) - } - - private[this] def rng[A, B](tree: Node[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Node[A, B] = { - if (tree eq null) return null - if (from == None && until == None) return tree - if (from != None && ordering.lt(tree.key, from.get)) return rng(tree.right, from, until); - if (until != None && ordering.lteq(until.get, tree.key)) return rng(tree.left, from, until); - val newLeft = rng(tree.left, from, None) - val newRight = rng(tree.right, None, until) - if ((newLeft eq tree.left) && (newRight eq tree.right)) tree - else if (newLeft eq null) upd(newRight, tree.key, tree.value); - else if (newRight eq null) upd(newLeft, tree.key, tree.value); - else rebalance(tree, newLeft, newRight) - } - - // The zipper returned might have been traversed left-most (always the left child) - // or right-most (always the right child). Left trees are traversed right-most, - // and right trees are traversed leftmost. - - // Returns the zipper for the side with deepest black nodes depth, a flag - // indicating whether the trees were unbalanced at all, and a flag indicating - // whether the zipper was traversed left-most or right-most. - - // If the trees were balanced, returns an empty zipper - private[this] def compareDepth[A, B](left: Node[A, B], right: Node[A, B]): (List[Node[A, B]], Boolean, Boolean, Int) = { - // Once a side is found to be deeper, unzip it to the bottom - def unzip(zipper: List[Node[A, B]], leftMost: Boolean): List[Node[A, B]] = { - val next = if (leftMost) zipper.head.left else zipper.head.right - next match { - case null => zipper - case node => unzip(node :: zipper, leftMost) - } - } - - // Unzip left tree on the rightmost side and right tree on the leftmost side until one is - // found to be deeper, or the bottom is reached - def unzipBoth(left: Node[A, B], - right: Node[A, B], - leftZipper: List[Node[A, B]], - rightZipper: List[Node[A, B]], - smallerDepth: Int): (List[Node[A, B]], Boolean, Boolean, Int) = { - if (isBlackNode(left) && isBlackNode(right)) { - unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth + 1) - } else if (isRedNode(left) && isRedNode(right)) { - unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth) - } else if (isRedNode(right)) { - unzipBoth(left, right.left, leftZipper, right :: rightZipper, smallerDepth) - } else if (isRedNode(left)) { - unzipBoth(left.right, right, left :: leftZipper, rightZipper, smallerDepth) - } else if ((left eq null) && (right eq null)) { - (Nil, true, false, smallerDepth) - } else if ((left eq null) && isBlackNode(right)) { - val leftMost = true - (unzip(right :: rightZipper, leftMost), false, leftMost, smallerDepth) - } else if (isBlackNode(left) && (right eq null)) { - val leftMost = false - (unzip(left :: leftZipper, leftMost), false, leftMost, smallerDepth) - } else { - sys.error("unmatched trees in unzip: " + left + ", " + right) - } - } - unzipBoth(left, right, Nil, Nil, 0) - } - private[this] def rebalance[A, B](tree: Node[A, B], newLeft: Node[A, B], newRight: Node[A, B]) = { - // This is like drop(n-1), but only counting black nodes - def findDepth(zipper: List[Node[A, B]], depth: Int): List[Node[A, B]] = zipper match { - case head :: tail if isBlackNode(head) => - if (depth == 1) zipper else findDepth(tail, depth - 1) - case _ :: tail => findDepth(tail, depth) - case Nil => sys.error("Defect: unexpected empty zipper while computing range") - } - - // Blackening the smaller tree avoids balancing problems on union; - // this can't be done later, though, or it would change the result of compareDepth - val blkNewLeft = blacken(newLeft) - val blkNewRight = blacken(newRight) - val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) - - if (levelled) { - BlackNode(tree.key, tree.value, blkNewLeft, blkNewRight) - } else { - val zipFrom = findDepth(zipper, smallerDepth) - val union = if (leftMost) { - RedNode(tree.key, tree.value, blkNewLeft, zipFrom.head) - } else { - RedNode(tree.key, tree.value, zipFrom.head, blkNewRight) - } - val zippedTree = zipFrom.tail.foldLeft(union: Node[A, B]) { (tree, node) => - if (leftMost) - balanceLeft(node.isBlack, node.key, node.value, tree, node.right) - else - balanceRight(node.isBlack, node.key, node.value, node.left, tree) - } - zippedTree - } - } - - /* - * Forcing direct fields access using the @inline annotation helps speed up - * various operations (especially smallest/greatest and update/delete). - * - * Unfortunately the direct field access is not guaranteed to work (but - * works on the current implementation of the Scala compiler). - * - * An alternative is to implement the these classes using plain old Java code... - */ - sealed abstract class Node[A, +B]( - @(inline @getter) final val key: A, - @(inline @getter) final val value: B, - @(inline @getter) final val left: Node[A, B], - @(inline @getter) final val right: Node[A, B]) - extends Serializable { - final val count: Int = 1 + RedBlack.count(left) + RedBlack.count(right) - def isBlack: Boolean - def black: Node[A, B] - def red: Node[A, B] - } - final class RedNode[A, +B](key: A, - value: B, - left: Node[A, B], - right: Node[A, B]) extends Node[A, B](key, value, left, right) { - override def isBlack = false - override def black = BlackNode(key, value, left, right) - override def red = this - override def toString = "RedNode(" + key + ", " + value + ", " + left + ", " + right + ")" - } - final class BlackNode[A, +B](key: A, - value: B, - left: Node[A, B], - right: Node[A, B]) extends Node[A, B](key, value, left, right) { - override def isBlack = true - override def black = this - override def red = RedNode(key, value, left, right) - override def toString = "BlackNode(" + key + ", " + value + ", " + left + ", " + right + ")" - } - - object RedNode { - @inline def apply[A, B](key: A, value: B, left: Node[A, B], right: Node[A, B]) = new RedNode(key, value, left, right) - def unapply[A, B](t: RedNode[A, B]) = Some((t.key, t.value, t.left, t.right)) - } - object BlackNode { - @inline def apply[A, B](key: A, value: B, left: Node[A, B], right: Node[A, B]) = new BlackNode(key, value, left, right) - def unapply[A, B](t: BlackNode[A, B]) = Some((t.key, t.value, t.left, t.right)) - } - - private[this] abstract class TreeIterator[A, B, R](tree: Node[A, B]) extends Iterator[R] { - protected[this] def nextResult(tree: Node[A, B]): R - - override def hasNext: Boolean = next ne null - - override def next: R = next match { - case null => - throw new NoSuchElementException("next on empty iterator") - case tree => - next = findNext(tree.right) - nextResult(tree) - } - - @tailrec - private[this] def findNext(tree: Node[A, B]): Node[A, B] = { - if (tree eq null) popPath() - else if (tree.left eq null) tree - else { - pushPath(tree) - findNext(tree.left) - } - } - - private[this] def pushPath(tree: Node[A, B]) { - try { - path(index) = tree - index += 1 - } catch { - case _: ArrayIndexOutOfBoundsException => - /* - * Either the tree became unbalanced or we calculated the maximum height incorrectly. - * To avoid crashing the iterator we expand the path array. Obviously this should never - * happen... - * - * An exception handler is used instead of an if-condition to optimize the normal path. - * This makes a large difference in iteration speed! - */ - assert(index >= path.length) - path :+= null - pushPath(tree) - } - } - private[this] def popPath(): Node[A, B] = if (index == 0) null else { - index -= 1 - path(index) - } - - private[this] var path = if (tree eq null) null else { - /* - * According to "Ralf Hinze. Constructing red-black trees" [http://www.cs.ox.ac.uk/ralf.hinze/publications/#P5] - * the maximum height of a red-black tree is 2*log_2(n + 2) - 2. - * - * According to {@see Integer#numberOfLeadingZeros} ceil(log_2(n)) = (32 - Integer.numberOfLeadingZeros(n - 1)) - * - * We also don't store the deepest nodes in the path so the maximum path length is further reduced by one. - */ - val maximumHeight = 2 * (32 - Integer.numberOfLeadingZeros(tree.count + 2 - 1)) - 2 - 1 - new Array[Node[A, B]](maximumHeight) - } - private[this] var index = 0 - private[this] var next: Node[A, B] = findNext(tree) - } - - private[this] class EntriesIterator[A, B](tree: Node[A, B]) extends TreeIterator[A, B, (A, B)](tree) { - override def nextResult(tree: Node[A, B]) = (tree.key, tree.value) - } - - private[this] class KeysIterator[A, B](tree: Node[A, B]) extends TreeIterator[A, B, A](tree) { - override def nextResult(tree: Node[A, B]) = tree.key - } - - private[this] class ValuesIterator[A, B](tree: Node[A, B]) extends TreeIterator[A, B, B](tree) { - override def nextResult(tree: Node[A, B]) = tree.value - } -} - - /** Old base class that was used by previous implementations of `TreeMaps` and `TreeSets`. * * Deprecated due to various performance bugs (see [[https://issues.scala-lang.org/browse/SI-5331 SI-5331]] for more information). diff --git a/src/library/scala/collection/immutable/RedBlackTree.scala b/src/library/scala/collection/immutable/RedBlackTree.scala new file mode 100644 index 0000000000..ebd88ce3fe --- /dev/null +++ b/src/library/scala/collection/immutable/RedBlackTree.scala @@ -0,0 +1,416 @@ +/* __ *\ +** ________ ___ / / ___ Scala API ** +** / __/ __// _ | / / / _ | (c) 2005-2011, LAMP/EPFL ** +** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ ** +** /____/\___/_/ |_/____/_/ | | ** +** |/ ** +\* */ + + + +package scala.collection +package immutable + +import annotation.tailrec +import annotation.meta.getter + +/** An object containing the RedBlack tree implementation used by for `TreeMaps` and `TreeSets`. + * + * Implementation note: since efficiency is important for data structures this implementation + * uses null to represent empty trees. This also means pattern matching cannot + * easily be used. The API represented by the RedBlackTree object tries to hide these + * optimizations behind a reasonably clean API. + * + * @since 2.10 + */ +private[immutable] +object RedBlackTree { + + def isEmpty(tree: Tree[_, _]): Boolean = tree eq null + + def contains[A](tree: Tree[A, _], x: A)(implicit ordering: Ordering[A]): Boolean = lookup(tree, x) ne null + def get[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Option[B] = lookup(tree, x) match { + case null => None + case tree => Some(tree.value) + } + + @tailrec + def lookup[A, B](tree: Tree[A, B], x: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else { + val cmp = ordering.compare(x, tree.key) + if (cmp < 0) lookup(tree.left, x) + else if (cmp > 0) lookup(tree.right, x) + else tree + } + + def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count + def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v)) + def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k)) + def range[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(rng(tree, from, until)) + + def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = { + if (tree eq null) throw new NoSuchElementException("empty map") + var result = tree + while (result.left ne null) result = result.left + result + } + def greatest[A, B](tree: Tree[A, B]): Tree[A, B] = { + if (tree eq null) throw new NoSuchElementException("empty map") + var result = tree + while (result.right ne null) result = result.right + result + } + + def foreach[A, B, U](tree: Tree[A, B], f: ((A, B)) => U): Unit = if (tree ne null) { + if (tree.left ne null) foreach(tree.left, f) + f((tree.key, tree.value)) + if (tree.right ne null) foreach(tree.right, f) + } + def foreachKey[A, U](tree: Tree[A, _], f: A => U): Unit = if (tree ne null) { + if (tree.left ne null) foreachKey(tree.left, f) + f(tree.key) + if (tree.right ne null) foreachKey(tree.right, f) + } + + def iterator[A, B](tree: Tree[A, B]): Iterator[(A, B)] = new EntriesIterator(tree) + def keysIterator[A, _](tree: Tree[A, _]): Iterator[A] = new KeysIterator(tree) + def valuesIterator[_, B](tree: Tree[_, B]): Iterator[B] = new ValuesIterator(tree) + + @tailrec + def nth[A, B](tree: Tree[A, B], n: Int): Tree[A, B] = { + val count = RedBlackTree.count(tree.left) + if (n < count) nth(tree.left, n) + else if (n > count) nth(tree.right, n - count - 1) + else tree + } + + def isBlack(tree: Tree[_, _]) = (tree eq null) || isBlackTree(tree) + + private[this] def isRedTree(tree: Tree[_, _]) = tree.isInstanceOf[RedTree[_, _]] + private[this] def isBlackTree(tree: Tree[_, _]) = tree.isInstanceOf[BlackTree[_, _]] + + private[this] def blacken[A, B](t: Tree[A, B]): Tree[A, B] = if (t eq null) null else t.black + + private[this] def mkTree[A, B](isBlack: Boolean, k: A, v: B, l: Tree[A, B], r: Tree[A, B]) = + if (isBlack) BlackTree(k, v, l, r) else RedTree(k, v, l, r) + + private[this] def balanceLeft[A, B, B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[A, B1], d: Tree[A, B1]): Tree[A, B1] = { + if (isRedTree(l) && isRedTree(l.left)) + RedTree(l.key, l.value, BlackTree(l.left.key, l.left.value, l.left.left, l.left.right), BlackTree(z, zv, l.right, d)) + else if (isRedTree(l) && isRedTree(l.right)) + RedTree(l.right.key, l.right.value, BlackTree(l.key, l.value, l.left, l.right.left), BlackTree(z, zv, l.right.right, d)) + else + mkTree(isBlack, z, zv, l, d) + } + private[this] def balanceRight[A, B, B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[A, B1], r: Tree[A, B1]): Tree[A, B1] = { + if (isRedTree(r) && isRedTree(r.left)) + RedTree(r.left.key, r.left.value, BlackTree(x, xv, a, r.left.left), BlackTree(r.key, r.value, r.left.right, r.right)) + else if (isRedTree(r) && isRedTree(r.right)) + RedTree(r.key, r.value, BlackTree(x, xv, a, r.left), BlackTree(r.right.key, r.right.value, r.right.left, r.right.right)) + else + mkTree(isBlack, x, xv, a, r) + } + private[this] def upd[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = if (tree eq null) { + RedTree(k, v, null, null) + } else { + val cmp = ordering.compare(k, tree.key) + if (cmp < 0) balanceLeft(isBlackTree(tree), tree.key, tree.value, upd(tree.left, k, v), tree.right) + else if (cmp > 0) balanceRight(isBlackTree(tree), tree.key, tree.value, tree.left, upd(tree.right, k, v)) + else mkTree(isBlackTree(tree), k, v, tree.left, tree.right) + } + + // Based on Stefan Kahrs' Haskell version of Okasaki's Red&Black Trees + // http://www.cse.unsw.edu.au/~dons/data/RedBlackTree.html + private[this] def del[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = if (tree eq null) null else { + def balance(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { + if (isRedTree(tr)) { + RedTree(x, xv, tl.black, tr.black) + } else if (isRedTree(tl.left)) { + RedTree(tl.key, tl.value, tl.left.black, BlackTree(x, xv, tl.right, tr)) + } else if (isRedTree(tl.right)) { + RedTree(tl.right.key, tl.right.value, BlackTree(tl.key, tl.value, tl.left, tl.right.left), BlackTree(x, xv, tl.right.right, tr)) + } else { + BlackTree(x, xv, tl, tr) + } + } else if (isRedTree(tr)) { + if (isRedTree(tr.right)) { + RedTree(tr.key, tr.value, BlackTree(x, xv, tl, tr.left), tr.right.black) + } else if (isRedTree(tr.left)) { + RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), BlackTree(tr.key, tr.value, tr.left.right, tr.right)) + } else { + BlackTree(x, xv, tl, tr) + } + } else { + BlackTree(x, xv, tl, tr) + } + def subl(t: Tree[A, B]) = + if (t.isInstanceOf[BlackTree[_, _]]) t.red + else sys.error("Defect: invariance violation; expected black, got "+t) + + def balLeft(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tl)) { + RedTree(x, xv, tl.black, tr) + } else if (isBlackTree(tr)) { + balance(x, xv, tl, tr.red) + } else if (isRedTree(tr) && isBlackTree(tr.left)) { + RedTree(tr.left.key, tr.left.value, BlackTree(x, xv, tl, tr.left.left), balance(tr.key, tr.value, tr.left.right, subl(tr.right))) + } else { + sys.error("Defect: invariance violation at ") // TODO + } + def balRight(x: A, xv: B, tl: Tree[A, B], tr: Tree[A, B]) = if (isRedTree(tr)) { + RedTree(x, xv, tl, tr.black) + } else if (isBlackTree(tl)) { + balance(x, xv, tl.red, tr) + } else if (isRedTree(tl) && isBlackTree(tl.right)) { + RedTree(tl.right.key, tl.right.value, balance(tl.key, tl.value, subl(tl.left), tl.right.left), BlackTree(x, xv, tl.right.right, tr)) + } else { + sys.error("Defect: invariance violation at ") // TODO + } + def delLeft = if (isBlackTree(tree.left)) balLeft(tree.key, tree.value, del(tree.left, k), tree.right) else RedTree(tree.key, tree.value, del(tree.left, k), tree.right) + def delRight = if (isBlackTree(tree.right)) balRight(tree.key, tree.value, tree.left, del(tree.right, k)) else RedTree(tree.key, tree.value, tree.left, del(tree.right, k)) + def append(tl: Tree[A, B], tr: Tree[A, B]): Tree[A, B] = if (tl eq null) { + tr + } else if (tr eq null) { + tl + } else if (isRedTree(tl) && isRedTree(tr)) { + val bc = append(tl.right, tr.left) + if (isRedTree(bc)) { + RedTree(bc.key, bc.value, RedTree(tl.key, tl.value, tl.left, bc.left), RedTree(tr.key, tr.value, bc.right, tr.right)) + } else { + RedTree(tl.key, tl.value, tl.left, RedTree(tr.key, tr.value, bc, tr.right)) + } + } else if (isBlackTree(tl) && isBlackTree(tr)) { + val bc = append(tl.right, tr.left) + if (isRedTree(bc)) { + RedTree(bc.key, bc.value, BlackTree(tl.key, tl.value, tl.left, bc.left), BlackTree(tr.key, tr.value, bc.right, tr.right)) + } else { + balLeft(tl.key, tl.value, tl.left, BlackTree(tr.key, tr.value, bc, tr.right)) + } + } else if (isRedTree(tr)) { + RedTree(tr.key, tr.value, append(tl, tr.left), tr.right) + } else if (isRedTree(tl)) { + RedTree(tl.key, tl.value, tl.left, append(tl.right, tr)) + } else { + sys.error("unmatched tree on append: " + tl + ", " + tr) + } + + val cmp = ordering.compare(k, tree.key) + if (cmp < 0) delLeft + else if (cmp > 0) delRight + else append(tree.left, tree.right) + } + + private[this] def rng[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = { + if (tree eq null) return null + if (from == None && until == None) return tree + if (from != None && ordering.lt(tree.key, from.get)) return rng(tree.right, from, until); + if (until != None && ordering.lteq(until.get, tree.key)) return rng(tree.left, from, until); + val newLeft = rng(tree.left, from, None) + val newRight = rng(tree.right, None, until) + if ((newLeft eq tree.left) && (newRight eq tree.right)) tree + else if (newLeft eq null) upd(newRight, tree.key, tree.value); + else if (newRight eq null) upd(newLeft, tree.key, tree.value); + else rebalance(tree, newLeft, newRight) + } + + // The zipper returned might have been traversed left-most (always the left child) + // or right-most (always the right child). Left trees are traversed right-most, + // and right trees are traversed leftmost. + + // Returns the zipper for the side with deepest black nodes depth, a flag + // indicating whether the trees were unbalanced at all, and a flag indicating + // whether the zipper was traversed left-most or right-most. + + // If the trees were balanced, returns an empty zipper + private[this] def compareDepth[A, B](left: Tree[A, B], right: Tree[A, B]): (List[Tree[A, B]], Boolean, Boolean, Int) = { + // Once a side is found to be deeper, unzip it to the bottom + def unzip(zipper: List[Tree[A, B]], leftMost: Boolean): List[Tree[A, B]] = { + val next = if (leftMost) zipper.head.left else zipper.head.right + next match { + case null => zipper + case node => unzip(node :: zipper, leftMost) + } + } + + // Unzip left tree on the rightmost side and right tree on the leftmost side until one is + // found to be deeper, or the bottom is reached + def unzipBoth(left: Tree[A, B], + right: Tree[A, B], + leftZipper: List[Tree[A, B]], + rightZipper: List[Tree[A, B]], + smallerDepth: Int): (List[Tree[A, B]], Boolean, Boolean, Int) = { + if (isBlackTree(left) && isBlackTree(right)) { + unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth + 1) + } else if (isRedTree(left) && isRedTree(right)) { + unzipBoth(left.right, right.left, left :: leftZipper, right :: rightZipper, smallerDepth) + } else if (isRedTree(right)) { + unzipBoth(left, right.left, leftZipper, right :: rightZipper, smallerDepth) + } else if (isRedTree(left)) { + unzipBoth(left.right, right, left :: leftZipper, rightZipper, smallerDepth) + } else if ((left eq null) && (right eq null)) { + (Nil, true, false, smallerDepth) + } else if ((left eq null) && isBlackTree(right)) { + val leftMost = true + (unzip(right :: rightZipper, leftMost), false, leftMost, smallerDepth) + } else if (isBlackTree(left) && (right eq null)) { + val leftMost = false + (unzip(left :: leftZipper, leftMost), false, leftMost, smallerDepth) + } else { + sys.error("unmatched trees in unzip: " + left + ", " + right) + } + } + unzipBoth(left, right, Nil, Nil, 0) + } + + private[this] def rebalance[A, B](tree: Tree[A, B], newLeft: Tree[A, B], newRight: Tree[A, B]) = { + // This is like drop(n-1), but only counting black nodes + def findDepth(zipper: List[Tree[A, B]], depth: Int): List[Tree[A, B]] = zipper match { + case head :: tail if isBlackTree(head) => + if (depth == 1) zipper else findDepth(tail, depth - 1) + case _ :: tail => findDepth(tail, depth) + case Nil => sys.error("Defect: unexpected empty zipper while computing range") + } + + // Blackening the smaller tree avoids balancing problems on union; + // this can't be done later, though, or it would change the result of compareDepth + val blkNewLeft = blacken(newLeft) + val blkNewRight = blacken(newRight) + val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) + + if (levelled) { + BlackTree(tree.key, tree.value, blkNewLeft, blkNewRight) + } else { + val zipFrom = findDepth(zipper, smallerDepth) + val union = if (leftMost) { + RedTree(tree.key, tree.value, blkNewLeft, zipFrom.head) + } else { + RedTree(tree.key, tree.value, zipFrom.head, blkNewRight) + } + val zippedTree = zipFrom.tail.foldLeft(union: Tree[A, B]) { (tree, node) => + if (leftMost) + balanceLeft(isBlackTree(node), node.key, node.value, tree, node.right) + else + balanceRight(isBlackTree(node), node.key, node.value, node.left, tree) + } + zippedTree + } + } + + /* + * Forcing direct fields access using the @inline annotation helps speed up + * various operations (especially smallest/greatest and update/delete). + * + * Unfortunately the direct field access is not guaranteed to work (but + * works on the current implementation of the Scala compiler). + * + * An alternative is to implement the these classes using plain old Java code... + */ + sealed abstract class Tree[A, +B]( + @(inline @getter) final val key: A, + @(inline @getter) final val value: B, + @(inline @getter) final val left: Tree[A, B], + @(inline @getter) final val right: Tree[A, B]) + extends Serializable { + final val count: Int = 1 + RedBlackTree.count(left) + RedBlackTree.count(right) + def black: Tree[A, B] + def red: Tree[A, B] + } + final class RedTree[A, +B](key: A, + value: B, + left: Tree[A, B], + right: Tree[A, B]) extends Tree[A, B](key, value, left, right) { + override def black: Tree[A, B] = BlackTree(key, value, left, right) + override def red: Tree[A, B] = this + override def toString: String = "RedTree(" + key + ", " + value + ", " + left + ", " + right + ")" + } + final class BlackTree[A, +B](key: A, + value: B, + left: Tree[A, B], + right: Tree[A, B]) extends Tree[A, B](key, value, left, right) { + override def black: Tree[A, B] = this + override def red: Tree[A, B] = RedTree(key, value, left, right) + override def toString: String = "BlackTree(" + key + ", " + value + ", " + left + ", " + right + ")" + } + + object RedTree { + @inline def apply[A, B](key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = new RedTree(key, value, left, right) + def unapply[A, B](t: RedTree[A, B]) = Some((t.key, t.value, t.left, t.right)) + } + object BlackTree { + @inline def apply[A, B](key: A, value: B, left: Tree[A, B], right: Tree[A, B]) = new BlackTree(key, value, left, right) + def unapply[A, B](t: BlackTree[A, B]) = Some((t.key, t.value, t.left, t.right)) + } + + private[this] abstract class TreeIterator[A, B, R](tree: Tree[A, B]) extends Iterator[R] { + protected[this] def nextResult(tree: Tree[A, B]): R + + override def hasNext: Boolean = next ne null + + override def next: R = next match { + case null => + throw new NoSuchElementException("next on empty iterator") + case tree => + next = findNext(tree.right) + nextResult(tree) + } + + @tailrec + private[this] def findNext(tree: Tree[A, B]): Tree[A, B] = { + if (tree eq null) popPath() + else if (tree.left eq null) tree + else { + pushPath(tree) + findNext(tree.left) + } + } + + private[this] def pushPath(tree: Tree[A, B]) { + try { + path(index) = tree + index += 1 + } catch { + case _: ArrayIndexOutOfBoundsException => + /* + * Either the tree became unbalanced or we calculated the maximum height incorrectly. + * To avoid crashing the iterator we expand the path array. Obviously this should never + * happen... + * + * An exception handler is used instead of an if-condition to optimize the normal path. + * This makes a large difference in iteration speed! + */ + assert(index >= path.length) + path :+= null + pushPath(tree) + } + } + private[this] def popPath(): Tree[A, B] = if (index == 0) null else { + index -= 1 + path(index) + } + + private[this] var path = if (tree eq null) null else { + /* + * According to "Ralf Hinze. Constructing red-black trees" [http://www.cs.ox.ac.uk/ralf.hinze/publications/#P5] + * the maximum height of a red-black tree is 2*log_2(n + 2) - 2. + * + * According to {@see Integer#numberOfLeadingZeros} ceil(log_2(n)) = (32 - Integer.numberOfLeadingZeros(n - 1)) + * + * We also don't store the deepest nodes in the path so the maximum path length is further reduced by one. + */ + val maximumHeight = 2 * (32 - Integer.numberOfLeadingZeros(tree.count + 2 - 1)) - 2 - 1 + new Array[Tree[A, B]](maximumHeight) + } + private[this] var index = 0 + private[this] var next: Tree[A, B] = findNext(tree) + } + + private[this] class EntriesIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, (A, B)](tree) { + override def nextResult(tree: Tree[A, B]) = (tree.key, tree.value) + } + + private[this] class KeysIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, A](tree) { + override def nextResult(tree: Tree[A, B]) = tree.key + } + + private[this] class ValuesIterator[A, B](tree: Tree[A, B]) extends TreeIterator[A, B, B](tree) { + override def nextResult(tree: Tree[A, B]) = tree.value + } +} diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 50244ef21d..196c3a9d9d 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -12,6 +12,7 @@ package scala.collection package immutable import generic._ +import immutable.{RedBlackTree => RB} import mutable.Builder import annotation.bridge @@ -45,14 +46,12 @@ object TreeMap extends ImmutableSortedMapFactory[TreeMap] { * @define mayNotTerminateInf * @define willNotTerminateInf */ -class TreeMap[A, +B] private (tree: RedBlack.Node[A, B])(implicit val ordering: Ordering[A]) +class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Ordering[A]) extends SortedMap[A, B] with SortedMapLike[A, B, TreeMap[A, B]] with MapLike[A, B, TreeMap[A, B]] with Serializable { - import immutable.{RedBlack => RB} - @deprecated("use `ordering.lt` instead", "2.10") def isSmaller(x: A, y: A) = ordering.lt(x, y) diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index 899ef0e5eb..12e2197732 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -12,6 +12,7 @@ package scala.collection package immutable import generic._ +import immutable.{RedBlackTree => RB} import mutable.{ Builder, SetBuilder } /** $factoryInfo @@ -47,11 +48,9 @@ object TreeSet extends ImmutableSortedSetFactory[TreeSet] { * @define willNotTerminateInf */ @SerialVersionUID(-5685982407650748405L) -class TreeSet[A] private (tree: RedBlack.Node[A, Unit])(implicit val ordering: Ordering[A]) +class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Ordering[A]) extends SortedSet[A] with SortedSetLike[A, TreeSet[A]] with Serializable { - import immutable.{RedBlack => RB} - override def stringPrefix = "TreeSet" override def size = RB.count(tree) @@ -105,7 +104,7 @@ class TreeSet[A] private (tree: RedBlack.Node[A, Unit])(implicit val ordering: O def this()(implicit ordering: Ordering[A]) = this(null)(ordering) - private def newSet(t: RedBlack.Node[A, Unit]) = new TreeSet[A](t) + private def newSet(t: RB.Tree[A, Unit]) = new TreeSet[A](t) /** A factory to create empty sets of the same type of keys. */ diff --git a/test/files/scalacheck/redblack.scala b/test/files/scalacheck/redblack.scala index 83d3ca0c1f..bbc6504f58 100644 --- a/test/files/scalacheck/redblack.scala +++ b/test/files/scalacheck/redblack.scala @@ -1,4 +1,3 @@ -import collection.immutable._ import org.scalacheck._ import Prop._ import Gen._ @@ -15,23 +14,26 @@ Both children of every red node are black. Every simple path from a given node to any of its descendant leaves contains the same number of black nodes. */ -package scala.collection.immutable { abstract class RedBlackTest extends Properties("RedBlack") { def minimumSize = 0 def maximumSize = 5 - import RedBlack._ + object RedBlackTest extends scala.collection.immutable.RedBlack[String] { + def isSmaller(x: String, y: String) = x < y + } + + import RedBlackTest._ - def nodeAt[A](tree: Node[String, A], n: Int): Option[(String, A)] = if (n < iterator(tree).size && n >= 0) - Some(iterator(tree).drop(n).next) + def nodeAt[A](tree: Tree[A], n: Int): Option[(String, A)] = if (n < tree.iterator.size && n >= 0) + Some(tree.iterator.drop(n).next) else None - def treeContains[A](tree: Node[String, A], key: String) = iterator(tree).map(_._1) contains key + def treeContains[A](tree: Tree[A], key: String) = tree.iterator.map(_._1) contains key - def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Node[String, Int]] = + def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[Int]] = if (level == 0) { - value(null) + value(Empty) } else { for { oddOrEven <- choose(0, 2) @@ -42,9 +44,9 @@ abstract class RedBlackTest extends Properties("RedBlack") { right <- mkTree(nextLevel, !isRed, label + "R") } yield { if (isRed) - RedNode(label + "N", 0, left, right) + RedTree(label + "N", 0, left, right) else - BlackNode(label + "N", 0, left, right) + BlackTree(label + "N", 0, left, right) } } @@ -54,10 +56,10 @@ abstract class RedBlackTest extends Properties("RedBlack") { } yield tree type ModifyParm - def genParm(tree: Node[String, Int]): Gen[ModifyParm] - def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] + def genParm(tree: Tree[Int]): Gen[ModifyParm] + def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] - def genInput: Gen[(Node[String, Int], ModifyParm, Node[String, Int])] = for { + def genInput: Gen[(Tree[Int], ModifyParm, Tree[Int])] = for { tree <- genTree parm <- genParm(tree) } yield (tree, parm, modify(tree, parm)) @@ -66,30 +68,30 @@ abstract class RedBlackTest extends Properties("RedBlack") { trait RedBlackInvariants { self: RedBlackTest => - import RedBlack._ + import RedBlackTest._ - def rootIsBlack[A](t: Node[String, A]) = isBlack(t) + def rootIsBlack[A](t: Tree[A]) = t.isBlack - def areAllLeavesBlack[A](t: Node[String, A]): Boolean = t match { - case null => isBlack(t) - case ne => List(ne.left, ne.right) forall areAllLeavesBlack + def areAllLeavesBlack[A](t: Tree[A]): Boolean = t match { + case Empty => t.isBlack + case ne: NonEmpty[_] => List(ne.left, ne.right) forall areAllLeavesBlack } - def areRedNodeChildrenBlack[A](t: Node[String, A]): Boolean = t match { - case RedNode(_, _, left, right) => List(left, right) forall (t => isBlack(t) && areRedNodeChildrenBlack(t)) - case BlackNode(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack - case null => true + def areRedNodeChildrenBlack[A](t: Tree[A]): Boolean = t match { + case RedTree(_, _, left, right) => List(left, right) forall (t => t.isBlack && areRedNodeChildrenBlack(t)) + case BlackTree(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack + case Empty => true } - def blackNodesToLeaves[A](t: Node[String, A]): List[Int] = t match { - case null => List(1) - case BlackNode(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) - case RedNode(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves + def blackNodesToLeaves[A](t: Tree[A]): List[Int] = t match { + case Empty => List(1) + case BlackTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) + case RedTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves } - def areBlackNodesToLeavesEqual[A](t: Node[String, A]): Boolean = t match { - case null => true - case ne => + def areBlackNodesToLeavesEqual[A](t: Tree[A]): Boolean = t match { + case Empty => true + case ne: NonEmpty[_] => ( blackNodesToLeaves(ne).distinct.size == 1 && areBlackNodesToLeavesEqual(ne.left) @@ -97,10 +99,10 @@ trait RedBlackInvariants { ) } - def orderIsPreserved[A](t: Node[String, A]): Boolean = - iterator(t) zip iterator(t).drop(1) forall { case (x, y) => x._1 < y._1 } + def orderIsPreserved[A](t: Tree[A]): Boolean = + t.iterator zip t.iterator.drop(1) forall { case (x, y) => isSmaller(x._1, y._1) } - def setup(invariant: Node[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => + def setup(invariant: Tree[Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => invariant(newTree) } @@ -112,13 +114,13 @@ trait RedBlackInvariants { } object TestInsert extends RedBlackTest with RedBlackInvariants { - import RedBlack._ + import RedBlackTest._ override type ModifyParm = Int - override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1) - override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = update(tree, generateKey(tree, parm), 0) + override def genParm(tree: Tree[Int]): Gen[ModifyParm] = choose(0, tree.iterator.size + 1) + override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = tree update (generateKey(tree, parm), 0) - def generateKey(tree: Node[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match { + def generateKey(tree: Tree[Int], parm: ModifyParm): String = nodeAt(tree, parm) match { case Some((key, _)) => key.init.mkString + "MN" case None => nodeAt(tree, parm - 1) match { case Some((key, _)) => key.init.mkString + "RN" @@ -132,31 +134,31 @@ object TestInsert extends RedBlackTest with RedBlackInvariants { } object TestModify extends RedBlackTest { - import RedBlack._ + import RedBlackTest._ def newValue = 1 override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) - override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = nodeAt(tree, parm) map { - case (key, _) => update(tree, key, newValue) + override def genParm(tree: Tree[Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) + override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = nodeAt(tree, parm) map { + case (key, _) => tree update (key, newValue) } getOrElse tree property("update modifies values") = forAll(genInput) { case (tree, parm, newTree) => nodeAt(tree,parm) forall { case (key, _) => - iterator(newTree) contains (key, newValue) + newTree.iterator contains (key, newValue) } } } object TestDelete extends RedBlackTest with RedBlackInvariants { - import RedBlack._ + import RedBlackTest._ override def minimumSize = 1 override type ModifyParm = Int - override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) - override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = nodeAt(tree, parm) map { - case (key, _) => delete(tree, key) + override def genParm(tree: Tree[Int]): Gen[ModifyParm] = choose(0, tree.iterator.size) + override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = nodeAt(tree, parm) map { + case (key, _) => tree delete key } getOrElse tree property("delete removes elements") = forAll(genInput) { case (tree, parm, newTree) => @@ -167,41 +169,40 @@ object TestDelete extends RedBlackTest with RedBlackInvariants { } object TestRange extends RedBlackTest with RedBlackInvariants { - import RedBlack._ + import RedBlackTest._ override type ModifyParm = (Option[Int], Option[Int]) - override def genParm(tree: Node[String, Int]): Gen[ModifyParm] = for { - from <- choose(0, iterator(tree).size) - to <- choose(0, iterator(tree).size) suchThat (from <=) + override def genParm(tree: Tree[Int]): Gen[ModifyParm] = for { + from <- choose(0, tree.iterator.size) + to <- choose(0, tree.iterator.size) suchThat (from <=) optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug } yield (optionalFrom, optionalTo) - override def modify(tree: Node[String, Int], parm: ModifyParm): Node[String, Int] = { + override def modify(tree: Tree[Int], parm: ModifyParm): Tree[Int] = { val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - range(tree, from, to) + tree range (from, to) } property("range boundaries respected") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - ("lower boundary" |: (from forall ( key => iterator(newTree).map(_._1) forall (key <=)))) && - ("upper boundary" |: (to forall ( key => iterator(newTree).map(_._1) forall (key >)))) + ("lower boundary" |: (from forall ( key => newTree.iterator.map(_._1) forall (key <=)))) && + ("upper boundary" |: (to forall ( key => newTree.iterator.map(_._1) forall (key >)))) } property("range returns all elements") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - val filteredTree = (iterator(tree) + val filteredTree = (tree.iterator .map(_._1) .filter(key => from forall (key >=)) .filter(key => to forall (key <)) .toList) - filteredTree == iterator(newTree).map(_._1).toList + filteredTree == newTree.iterator.map(_._1).toList } } -} object Test extends Properties("RedBlack") { include(TestInsert) diff --git a/test/files/scalacheck/redblacktree.scala b/test/files/scalacheck/redblacktree.scala new file mode 100644 index 0000000000..10f3f0fbbf --- /dev/null +++ b/test/files/scalacheck/redblacktree.scala @@ -0,0 +1,212 @@ +import collection.immutable.{RedBlackTree => RB} +import org.scalacheck._ +import Prop._ +import Gen._ + +/* +Properties of a Red & Black Tree: + +A node is either red or black. +The root is black. (This rule is used in some definitions and not others. Since the +root can always be changed from red to black but not necessarily vice-versa this +rule has little effect on analysis.) +All leaves are black. +Both children of every red node are black. +Every simple path from a given node to any of its descendant leaves contains the same number of black nodes. +*/ + +package scala.collection.immutable.redblacktree { + abstract class RedBlackTreeTest extends Properties("RedBlackTree") { + def minimumSize = 0 + def maximumSize = 5 + + import RB._ + + def nodeAt[A](tree: Tree[String, A], n: Int): Option[(String, A)] = if (n < iterator(tree).size && n >= 0) + Some(iterator(tree).drop(n).next) + else + None + + def treeContains[A](tree: Tree[String, A], key: String) = iterator(tree).map(_._1) contains key + + def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[String, Int]] = + if (level == 0) { + value(null) + } else { + for { + oddOrEven <- choose(0, 2) + tryRed = oddOrEven.sample.get % 2 == 0 // work around arbitrary[Boolean] bug + isRed = parentIsBlack && tryRed + nextLevel = if (isRed) level else level - 1 + left <- mkTree(nextLevel, !isRed, label + "L") + right <- mkTree(nextLevel, !isRed, label + "R") + } yield { + if (isRed) + RedTree(label + "N", 0, left, right) + else + BlackTree(label + "N", 0, left, right) + } + } + + def genTree = for { + depth <- choose(minimumSize, maximumSize + 1) + tree <- mkTree(depth) + } yield tree + + type ModifyParm + def genParm(tree: Tree[String, Int]): Gen[ModifyParm] + def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] + + def genInput: Gen[(Tree[String, Int], ModifyParm, Tree[String, Int])] = for { + tree <- genTree + parm <- genParm(tree) + } yield (tree, parm, modify(tree, parm)) + } + + trait RedBlackTreeInvariants { + self: RedBlackTreeTest => + + import RB._ + + def rootIsBlack[A](t: Tree[String, A]) = isBlack(t) + + def areAllLeavesBlack[A](t: Tree[String, A]): Boolean = t match { + case null => isBlack(t) + case ne => List(ne.left, ne.right) forall areAllLeavesBlack + } + + def areRedNodeChildrenBlack[A](t: Tree[String, A]): Boolean = t match { + case RedTree(_, _, left, right) => List(left, right) forall (t => isBlack(t) && areRedNodeChildrenBlack(t)) + case BlackTree(_, _, left, right) => List(left, right) forall areRedNodeChildrenBlack + case null => true + } + + def blackNodesToLeaves[A](t: Tree[String, A]): List[Int] = t match { + case null => List(1) + case BlackTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves map (_ + 1) + case RedTree(_, _, left, right) => List(left, right) flatMap blackNodesToLeaves + } + + def areBlackNodesToLeavesEqual[A](t: Tree[String, A]): Boolean = t match { + case null => true + case ne => + ( + blackNodesToLeaves(ne).distinct.size == 1 + && areBlackNodesToLeavesEqual(ne.left) + && areBlackNodesToLeavesEqual(ne.right) + ) + } + + def orderIsPreserved[A](t: Tree[String, A]): Boolean = + iterator(t) zip iterator(t).drop(1) forall { case (x, y) => x._1 < y._1 } + + def setup(invariant: Tree[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => + invariant(newTree) + } + + property("root is black") = setup(rootIsBlack) + property("all leaves are black") = setup(areAllLeavesBlack) + property("children of red nodes are black") = setup(areRedNodeChildrenBlack) + property("black nodes are balanced") = setup(areBlackNodesToLeavesEqual) + property("ordering of keys is preserved") = setup(orderIsPreserved) + } + + object TestInsert extends RedBlackTreeTest with RedBlackTreeInvariants { + import RB._ + + override type ModifyParm = Int + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size + 1) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = update(tree, generateKey(tree, parm), 0) + + def generateKey(tree: Tree[String, Int], parm: ModifyParm): String = nodeAt(tree, parm) match { + case Some((key, _)) => key.init.mkString + "MN" + case None => nodeAt(tree, parm - 1) match { + case Some((key, _)) => key.init.mkString + "RN" + case None => "N" + } + } + + property("update adds elements") = forAll(genInput) { case (tree, parm, newTree) => + treeContains(newTree, generateKey(tree, parm)) + } + } + + object TestModify extends RedBlackTreeTest { + import RB._ + + def newValue = 1 + override def minimumSize = 1 + override type ModifyParm = Int + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { + case (key, _) => update(tree, key, newValue) + } getOrElse tree + + property("update modifies values") = forAll(genInput) { case (tree, parm, newTree) => + nodeAt(tree,parm) forall { case (key, _) => + iterator(newTree) contains (key, newValue) + } + } + } + + object TestDelete extends RedBlackTreeTest with RedBlackTreeInvariants { + import RB._ + + override def minimumSize = 1 + override type ModifyParm = Int + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = choose(0, iterator(tree).size) + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = nodeAt(tree, parm) map { + case (key, _) => delete(tree, key) + } getOrElse tree + + property("delete removes elements") = forAll(genInput) { case (tree, parm, newTree) => + nodeAt(tree, parm) forall { case (key, _) => + !treeContains(newTree, key) + } + } + } + + object TestRange extends RedBlackTreeTest with RedBlackTreeInvariants { + import RB._ + + override type ModifyParm = (Option[Int], Option[Int]) + override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { + from <- choose(0, iterator(tree).size) + to <- choose(0, iterator(tree).size) suchThat (from <=) + optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug + optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug + } yield (optionalFrom, optionalTo) + + override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = { + val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) + val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) + range(tree, from, to) + } + + property("range boundaries respected") = forAll(genInput) { case (tree, parm, newTree) => + val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) + val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) + ("lower boundary" |: (from forall ( key => iterator(newTree).map(_._1) forall (key <=)))) && + ("upper boundary" |: (to forall ( key => iterator(newTree).map(_._1) forall (key >)))) + } + + property("range returns all elements") = forAll(genInput) { case (tree, parm, newTree) => + val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) + val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) + val filteredTree = (iterator(tree) + .map(_._1) + .filter(key => from forall (key >=)) + .filter(key => to forall (key <)) + .toList) + filteredTree == iterator(newTree).map(_._1).toList + } + } +} + +object Test extends Properties("RedBlackTree") { + import collection.immutable.redblacktree._ + include(TestInsert) + include(TestModify) + include(TestDelete) + include(TestRange) +} -- cgit v1.2.3 From e61075c4e173d8fad5127e90046f5b91e97c3180 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sat, 7 Jan 2012 19:20:46 +0100 Subject: Tests for takeWhile/dropWhile/span. Also simplified implementation of span to just use splitAt. --- src/library/scala/collection/immutable/TreeMap.scala | 5 +---- src/library/scala/collection/immutable/TreeSet.scala | 5 +---- test/files/scalacheck/treemap.scala | 15 +++++++++++++++ test/files/scalacheck/treeset.scala | 15 +++++++++++++++ 4 files changed, 32 insertions(+), 8 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 196c3a9d9d..2bb8a566c6 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -116,10 +116,7 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi } override def dropWhile(p: ((A, B)) => Boolean) = drop(countWhile(p)) override def takeWhile(p: ((A, B)) => Boolean) = take(countWhile(p)) - override def span(p: ((A, B)) => Boolean) = { - val n = countWhile(p) - (take(n), drop(n)) - } + override def span(p: ((A, B)) => Boolean) = splitAt(countWhile(p)) /** A factory to create empty maps of the same type of keys. */ diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index 12e2197732..8b95358d1c 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -94,10 +94,7 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin } override def dropWhile(p: A => Boolean) = drop(countWhile(p)) override def takeWhile(p: A => Boolean) = take(countWhile(p)) - override def span(p: A => Boolean) = { - val n = countWhile(p) - (take(n), drop(n)) - } + override def span(p: A => Boolean) = splitAt(countWhile(p)) @deprecated("use `ordering.lt` instead", "2.10") def isSmaller(x: A, y: A) = compare(x,y) < 0 diff --git a/test/files/scalacheck/treemap.scala b/test/files/scalacheck/treemap.scala index 9970bb01aa..7d5f94d58b 100644 --- a/test/files/scalacheck/treemap.scala +++ b/test/files/scalacheck/treemap.scala @@ -96,6 +96,21 @@ object Test extends Properties("TreeMap") { prefix == subject.take(n) && suffix == subject.drop(n) } + property("takeWhile") = forAll { (subject: TreeMap[Int, String]) => + val result = subject.takeWhile(_._1 < 0) + result.forall(_._1 < 0) && result == subject.take(result.size) + } + + property("dropWhile") = forAll { (subject: TreeMap[Int, String]) => + val result = subject.dropWhile(_._1 < 0) + result.forall(_._1 >= 0) && result == subject.takeRight(result.size) + } + + property("span identity") = forAll { (subject: TreeMap[Int, String]) => + val (prefix, suffix) = subject.span(_._1 < 0) + prefix.forall(_._1 < 0) && suffix.forall(_._1 >= 0) && subject == prefix ++ suffix + } + property("remove single") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { val key = oneOf(subject.keys.toSeq).sample.get val removed = subject - key diff --git a/test/files/scalacheck/treeset.scala b/test/files/scalacheck/treeset.scala index 87c3eb7108..e47a1b6cdd 100644 --- a/test/files/scalacheck/treeset.scala +++ b/test/files/scalacheck/treeset.scala @@ -92,6 +92,21 @@ object Test extends Properties("TreeSet") { prefix == subject.take(n) && suffix == subject.drop(n) } + property("takeWhile") = forAll { (subject: TreeMap[Int, String]) => + val result = subject.takeWhile(_._1 < 0) + result.forall(_._1 < 0) && result == subject.take(result.size) + } + + property("dropWhile") = forAll { (subject: TreeMap[Int, String]) => + val result = subject.dropWhile(_._1 < 0) + result.forall(_._1 >= 0) && result == subject.takeRight(result.size) + } + + property("span identity") = forAll { (subject: TreeMap[Int, String]) => + val (prefix, suffix) = subject.span(_._1 < 0) + prefix.forall(_._1 < 0) && suffix.forall(_._1 >= 0) && subject == prefix ++ suffix + } + property("remove single") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { val element = oneOf(subject.toSeq).sample.get val removed = subject - element -- cgit v1.2.3 From 8b3f984d4e2e444c0712a7457aefd159d4024b1f Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sat, 7 Jan 2012 23:31:06 +0100 Subject: Fix silly copy-paste error. --- test/files/scalacheck/treeset.scala | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'test/files/scalacheck') diff --git a/test/files/scalacheck/treeset.scala b/test/files/scalacheck/treeset.scala index e47a1b6cdd..7f99aec77e 100644 --- a/test/files/scalacheck/treeset.scala +++ b/test/files/scalacheck/treeset.scala @@ -92,19 +92,19 @@ object Test extends Properties("TreeSet") { prefix == subject.take(n) && suffix == subject.drop(n) } - property("takeWhile") = forAll { (subject: TreeMap[Int, String]) => - val result = subject.takeWhile(_._1 < 0) - result.forall(_._1 < 0) && result == subject.take(result.size) + property("takeWhile") = forAll { (subject: TreeSet[Int]) => + val result = subject.takeWhile(_ < 0) + result.forall(_ < 0) && result == subject.take(result.size) } - property("dropWhile") = forAll { (subject: TreeMap[Int, String]) => - val result = subject.dropWhile(_._1 < 0) - result.forall(_._1 >= 0) && result == subject.takeRight(result.size) + property("dropWhile") = forAll { (subject: TreeSet[Int]) => + val result = subject.dropWhile(_ < 0) + result.forall(_ >= 0) && result == subject.takeRight(result.size) } - property("span identity") = forAll { (subject: TreeMap[Int, String]) => - val (prefix, suffix) = subject.span(_._1 < 0) - prefix.forall(_._1 < 0) && suffix.forall(_._1 >= 0) && subject == prefix ++ suffix + property("span identity") = forAll { (subject: TreeSet[Int]) => + val (prefix, suffix) = subject.span(_ < 0) + prefix.forall(_ < 0) && suffix.forall(_ >= 0) && subject == prefix ++ suffix } property("remove single") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { -- cgit v1.2.3 From f26f610278887b842de3a4e4fdafb866dd1afb62 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sun, 8 Jan 2012 12:59:45 +0100 Subject: Test for maximum height of red-black tree. --- test/files/scalacheck/redblacktree.scala | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'test/files/scalacheck') diff --git a/test/files/scalacheck/redblacktree.scala b/test/files/scalacheck/redblacktree.scala index 10f3f0fbbf..34fa8eae8d 100644 --- a/test/files/scalacheck/redblacktree.scala +++ b/test/files/scalacheck/redblacktree.scala @@ -29,6 +29,8 @@ package scala.collection.immutable.redblacktree { def treeContains[A](tree: Tree[String, A], key: String) = iterator(tree).map(_._1) contains key + def height(tree: Tree[_, _]): Int = if (tree eq null) 0 else (1 + math.max(height(tree.left), height(tree.right))) + def mkTree(level: Int, parentIsBlack: Boolean = false, label: String = ""): Gen[Tree[String, Int]] = if (level == 0) { value(null) @@ -100,6 +102,8 @@ package scala.collection.immutable.redblacktree { def orderIsPreserved[A](t: Tree[String, A]): Boolean = iterator(t) zip iterator(t).drop(1) forall { case (x, y) => x._1 < y._1 } + def heightIsBounded(t: Tree[_, _]): Boolean = height(t) <= (2 * (32 - Integer.numberOfLeadingZeros(count(t) + 2)) - 2) + def setup(invariant: Tree[String, Int] => Boolean) = forAll(genInput) { case (tree, parm, newTree) => invariant(newTree) } @@ -109,6 +113,7 @@ package scala.collection.immutable.redblacktree { property("children of red nodes are black") = setup(areRedNodeChildrenBlack) property("black nodes are balanced") = setup(areBlackNodesToLeavesEqual) property("ordering of keys is preserved") = setup(orderIsPreserved) + property("height is bounded") = setup(heightIsBounded) } object TestInsert extends RedBlackTreeTest with RedBlackTreeInvariants { -- cgit v1.2.3 From 00b5cb84df493aace270674054d2f6ddf3721131 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sun, 15 Jan 2012 13:48:00 +0100 Subject: Optimized implementation of TreeMap/TreeSet#to method. Performance of `to` and `until` is now the same. --- .../scala/collection/immutable/RedBlackTree.scala | 18 ++++++++----- .../scala/collection/immutable/TreeMap.scala | 10 ++++--- .../scala/collection/immutable/TreeSet.scala | 6 ++++- test/files/scalacheck/redblacktree.scala | 31 +++++++++++++--------- test/files/scalacheck/treemap.scala | 18 +++++++++++++ test/files/scalacheck/treeset.scala | 18 +++++++++++++ 6 files changed, 77 insertions(+), 24 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlackTree.scala b/src/library/scala/collection/immutable/RedBlackTree.scala index ebd88ce3fe..d8caeab096 100644 --- a/src/library/scala/collection/immutable/RedBlackTree.scala +++ b/src/library/scala/collection/immutable/RedBlackTree.scala @@ -45,7 +45,11 @@ object RedBlackTree { def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v)) def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k)) - def range[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = blacken(rng(tree, from, until)) + def range[A, B](tree: Tree[A, B], low: Option[A], lowInclusive: Boolean, high: Option[A], highInclusive: Boolean)(implicit ordering: Ordering[A]): Tree[A, B] = { + val after: Option[A => Boolean] = low.map(key => if (lowInclusive) ordering.lt(_, key) else ordering.lteq(_, key)) + val before: Option[A => Boolean] = high.map(key => if (highInclusive) ordering.lt(key, _) else ordering.lteq(key, _)) + blacken(rng(tree, after, before)) + } def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = { if (tree eq null) throw new NoSuchElementException("empty map") @@ -198,13 +202,13 @@ object RedBlackTree { else append(tree.left, tree.right) } - private[this] def rng[A, B](tree: Tree[A, B], from: Option[A], until: Option[A])(implicit ordering: Ordering[A]): Tree[A, B] = { + private[this] def rng[A, B](tree: Tree[A, B], after: Option[A => Boolean], before: Option[A => Boolean])(implicit ordering: Ordering[A]): Tree[A, B] = { if (tree eq null) return null - if (from == None && until == None) return tree - if (from != None && ordering.lt(tree.key, from.get)) return rng(tree.right, from, until); - if (until != None && ordering.lteq(until.get, tree.key)) return rng(tree.left, from, until); - val newLeft = rng(tree.left, from, None) - val newRight = rng(tree.right, None, until) + if (after == None && before == None) return tree + if (after != None && after.get(tree.key)) return rng(tree.right, after, before); + if (before != None && before.get(tree.key)) return rng(tree.left, after, before); + val newLeft = rng(tree.left, after, None) + val newRight = rng(tree.right, None, before) if ((newLeft eq tree.left) && (newRight eq tree.right)) tree else if (newLeft eq null) upd(newRight, tree.key, tree.value); else if (newRight eq null) upd(newLeft, tree.key, tree.value); diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 2bb8a566c6..3eba64dca3 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -62,9 +62,13 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi def this()(implicit ordering: Ordering[A]) = this(null)(ordering) - override def rangeImpl(from : Option[A], until : Option[A]): TreeMap[A,B] = { - val ntree = RB.range(tree, from,until) - new TreeMap[A,B](ntree) + override def rangeImpl(from: Option[A], until: Option[A]): TreeMap[A, B] = { + val ntree = RB.range(tree, from, true, until, false) + new TreeMap[A, B](ntree) + } + override def to(to: A): TreeMap[A, B] = { + val ntree = RB.range(tree, None, true, Some(to), true) + new TreeMap[A, B](ntree) } override def firstKey = RB.smallest(tree).key diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index 8b95358d1c..5dd80e87a4 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -151,7 +151,11 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin override def foreach[U](f: A => U) = RB.foreachKey(tree, f) override def rangeImpl(from: Option[A], until: Option[A]): TreeSet[A] = { - val ntree = RB.range(tree, from, until) + val ntree = RB.range(tree, from, true, until, false) + newSet(ntree) + } + override def to(to: A): TreeSet[A] = { + val ntree = RB.range(tree, None, true, Some(to), true) newSet(ntree) } override def firstKey = head diff --git a/test/files/scalacheck/redblacktree.scala b/test/files/scalacheck/redblacktree.scala index 34fa8eae8d..14538c2352 100644 --- a/test/files/scalacheck/redblacktree.scala +++ b/test/files/scalacheck/redblacktree.scala @@ -174,36 +174,41 @@ package scala.collection.immutable.redblacktree { object TestRange extends RedBlackTreeTest with RedBlackTreeInvariants { import RB._ - override type ModifyParm = (Option[Int], Option[Int]) + override type ModifyParm = (Option[Int], Boolean, Option[Int], Boolean) override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { from <- choose(0, iterator(tree).size) + fromInclusive <- oneOf(false, true) to <- choose(0, iterator(tree).size) suchThat (from <=) + toInclusive <- oneOf(false, true) optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug - } yield (optionalFrom, optionalTo) + } yield (optionalFrom, fromInclusive, optionalTo, toInclusive) override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = { val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) - val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - range(tree, from, to) + val to = parm._3 flatMap (nodeAt(tree, _) map (_._1)) + range(tree, from, parm._2, to, parm._4) } property("range boundaries respected") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) - val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - ("lower boundary" |: (from forall ( key => iterator(newTree).map(_._1) forall (key <=)))) && - ("upper boundary" |: (to forall ( key => iterator(newTree).map(_._1) forall (key >)))) + val fromPredicate: String => String => Boolean = if (parm._2) (_ <=) else (_ <) + val to = parm._3 flatMap (nodeAt(tree, _) map (_._1)) + val toPredicate: String => String => Boolean = if (parm._4) (_ >=) else (_ >) + ("lower boundary" |: (from forall ( key => keysIterator(newTree) forall fromPredicate(key)))) && + ("upper boundary" |: (to forall ( key => keysIterator(newTree) forall toPredicate(key)))) } property("range returns all elements") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) - val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) - val filteredTree = (iterator(tree) - .map(_._1) - .filter(key => from forall (key >=)) - .filter(key => to forall (key <)) + val fromPredicate: String => String => Boolean = if (parm._2) (_ >=) else (_ >) + val to = parm._3 flatMap (nodeAt(tree, _) map (_._1)) + val toPredicate: String => String => Boolean = if (parm._4) (_ <=) else (_ <) + val filteredTree = (keysIterator(tree) + .filter(key => from forall fromPredicate(key)) + .filter(key => to forall toPredicate(key)) .toList) - filteredTree == iterator(newTree).map(_._1).toList + filteredTree == keysIterator(newTree).toList } } } diff --git a/test/files/scalacheck/treemap.scala b/test/files/scalacheck/treemap.scala index 7d5f94d58b..ba6d117fd4 100644 --- a/test/files/scalacheck/treemap.scala +++ b/test/files/scalacheck/treemap.scala @@ -111,6 +111,24 @@ object Test extends Properties("TreeMap") { prefix.forall(_._1 < 0) && suffix.forall(_._1 >= 0) && subject == prefix ++ suffix } + property("from is inclusive") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { + val n = choose(0, subject.size - 1).sample.get + val from = subject.drop(n).firstKey + subject.from(from).firstKey == from && subject.from(from).forall(_._1 >= from) + }} + + property("to is inclusive") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { + val n = choose(0, subject.size - 1).sample.get + val to = subject.drop(n).firstKey + subject.to(to).lastKey == to && subject.to(to).forall(_._1 <= to) + }} + + property("until is exclusive") = forAll { (subject: TreeMap[Int, String]) => subject.size > 1 ==> { + val n = choose(1, subject.size - 1).sample.get + val until = subject.drop(n).firstKey + subject.until(until).lastKey == subject.take(n).lastKey && subject.until(until).forall(_._1 <= until) + }} + property("remove single") = forAll { (subject: TreeMap[Int, String]) => subject.nonEmpty ==> { val key = oneOf(subject.keys.toSeq).sample.get val removed = subject - key diff --git a/test/files/scalacheck/treeset.scala b/test/files/scalacheck/treeset.scala index 7f99aec77e..e6d1b50860 100644 --- a/test/files/scalacheck/treeset.scala +++ b/test/files/scalacheck/treeset.scala @@ -107,6 +107,24 @@ object Test extends Properties("TreeSet") { prefix.forall(_ < 0) && suffix.forall(_ >= 0) && subject == prefix ++ suffix } + property("from is inclusive") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { + val n = choose(0, subject.size - 1).sample.get + val from = subject.drop(n).firstKey + subject.from(from).firstKey == from && subject.from(from).forall(_ >= from) + }} + + property("to is inclusive") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { + val n = choose(0, subject.size - 1).sample.get + val to = subject.drop(n).firstKey + subject.to(to).lastKey == to && subject.to(to).forall(_ <= to) + }} + + property("until is exclusive") = forAll { (subject: TreeSet[Int]) => subject.size > 1 ==> { + val n = choose(1, subject.size - 1).sample.get + val until = subject.drop(n).firstKey + subject.until(until).lastKey == subject.take(n).lastKey && subject.until(until).forall(_ <= until) + }} + property("remove single") = forAll { (subject: TreeSet[Int]) => subject.nonEmpty ==> { val element = oneOf(subject.toSeq).sample.get val removed = subject - element -- cgit v1.2.3 From 7824dbd3cfe6704ab56aa5ceb2af2f5f4e55cbc7 Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sat, 21 Jan 2012 22:55:59 +0100 Subject: Custom coded version of range/from/to/until. This avoids unnecessary allocation of Option and Function objects, mostly helping performance of small trees. --- .../scala/collection/immutable/RedBlackTree.scala | 48 +++++++++++++++++----- .../scala/collection/immutable/TreeMap.scala | 13 +++--- .../scala/collection/immutable/TreeSet.scala | 14 +++---- test/files/scalacheck/redblacktree.scala | 26 +++++------- 4 files changed, 59 insertions(+), 42 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlackTree.scala b/src/library/scala/collection/immutable/RedBlackTree.scala index d8caeab096..7110ca4194 100644 --- a/src/library/scala/collection/immutable/RedBlackTree.scala +++ b/src/library/scala/collection/immutable/RedBlackTree.scala @@ -45,11 +45,16 @@ object RedBlackTree { def count(tree: Tree[_, _]) = if (tree eq null) 0 else tree.count def update[A, B, B1 >: B](tree: Tree[A, B], k: A, v: B1)(implicit ordering: Ordering[A]): Tree[A, B1] = blacken(upd(tree, k, v)) def delete[A, B](tree: Tree[A, B], k: A)(implicit ordering: Ordering[A]): Tree[A, B] = blacken(del(tree, k)) - def range[A, B](tree: Tree[A, B], low: Option[A], lowInclusive: Boolean, high: Option[A], highInclusive: Boolean)(implicit ordering: Ordering[A]): Tree[A, B] = { - val after: Option[A => Boolean] = low.map(key => if (lowInclusive) ordering.lt(_, key) else ordering.lteq(_, key)) - val before: Option[A => Boolean] = high.map(key => if (highInclusive) ordering.lt(key, _) else ordering.lteq(key, _)) - blacken(rng(tree, after, before)) + def rangeImpl[A: Ordering, B](tree: Tree[A, B], from: Option[A], until: Option[A]): Tree[A, B] = (from, until) match { + case (Some(from), Some(until)) => this.range(tree, from, until) + case (Some(from), None) => this.from(tree, from) + case (None, Some(until)) => this.until(tree, until) + case (None, None) => tree } + def range[A: Ordering, B](tree: Tree[A, B], from: A, until: A): Tree[A, B] = blacken(doRange(tree, from, until)) + def from[A: Ordering, B](tree: Tree[A, B], from: A): Tree[A, B] = blacken(doFrom(tree, from)) + def to[A: Ordering, B](tree: Tree[A, B], to: A): Tree[A, B] = blacken(doTo(tree, to)) + def until[A: Ordering, B](tree: Tree[A, B], key: A): Tree[A, B] = blacken(doUntil(tree, key)) def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = { if (tree eq null) throw new NoSuchElementException("empty map") @@ -202,13 +207,36 @@ object RedBlackTree { else append(tree.left, tree.right) } - private[this] def rng[A, B](tree: Tree[A, B], after: Option[A => Boolean], before: Option[A => Boolean])(implicit ordering: Ordering[A]): Tree[A, B] = { + private[this] def doFrom[A, B](tree: Tree[A, B], from: A)(implicit ordering: Ordering[A]): Tree[A, B] = { if (tree eq null) return null - if (after == None && before == None) return tree - if (after != None && after.get(tree.key)) return rng(tree.right, after, before); - if (before != None && before.get(tree.key)) return rng(tree.left, after, before); - val newLeft = rng(tree.left, after, None) - val newRight = rng(tree.right, None, before) + if (ordering.lt(tree.key, from)) return doFrom(tree.right, from) + val newLeft = doFrom(tree.left, from) + if (newLeft eq tree.left) tree + else if (newLeft eq null) upd(tree.right, tree.key, tree.value) + else rebalance(tree, newLeft, tree.right) + } + private[this] def doTo[A, B](tree: Tree[A, B], to: A)(implicit ordering: Ordering[A]): Tree[A, B] = { + if (tree eq null) return null + if (ordering.lt(to, tree.key)) return doTo(tree.left, to) + val newRight = doTo(tree.right, to) + if (newRight eq tree.right) tree + else if (newRight eq null) upd(tree.left, tree.key, tree.value) + else rebalance(tree, tree.left, newRight) + } + private[this] def doUntil[A, B](tree: Tree[A, B], until: A)(implicit ordering: Ordering[A]): Tree[A, B] = { + if (tree eq null) return null + if (ordering.lteq(until, tree.key)) return doUntil(tree.left, until) + val newRight = doUntil(tree.right, until) + if (newRight eq tree.right) tree + else if (newRight eq null) upd(tree.left, tree.key, tree.value) + else rebalance(tree, tree.left, newRight) + } + private[this] def doRange[A, B](tree: Tree[A, B], from: A, until: A)(implicit ordering: Ordering[A]): Tree[A, B] = { + if (tree eq null) return null + if (ordering.lt(tree.key, from)) return doRange(tree.right, from, until); + if (ordering.lteq(until, tree.key)) return doRange(tree.left, from, until); + val newLeft = doFrom(tree.left, from) + val newRight = doUntil(tree.right, until) if ((newLeft eq tree.left) && (newRight eq tree.right)) tree else if (newLeft eq null) upd(newRight, tree.key, tree.value); else if (newRight eq null) upd(newLeft, tree.key, tree.value); diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 3eba64dca3..a24221decc 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -62,14 +62,11 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi def this()(implicit ordering: Ordering[A]) = this(null)(ordering) - override def rangeImpl(from: Option[A], until: Option[A]): TreeMap[A, B] = { - val ntree = RB.range(tree, from, true, until, false) - new TreeMap[A, B](ntree) - } - override def to(to: A): TreeMap[A, B] = { - val ntree = RB.range(tree, None, true, Some(to), true) - new TreeMap[A, B](ntree) - } + override def rangeImpl(from: Option[A], until: Option[A]): TreeMap[A, B] = new TreeMap[A, B](RB.rangeImpl(tree, from, until)) + override def range(from: A, until: A): TreeMap[A, B] = new TreeMap[A, B](RB.range(tree, from, until)) + override def from(from: A): TreeMap[A, B] = new TreeMap[A, B](RB.from(tree, from)) + override def to(to: A): TreeMap[A, B] = new TreeMap[A, B](RB.to(tree, to)) + override def until(until: A): TreeMap[A, B] = new TreeMap[A, B](RB.until(tree, until)) override def firstKey = RB.smallest(tree).key override def lastKey = RB.greatest(tree).key diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index 5dd80e87a4..e21aec362c 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -150,14 +150,12 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin override def foreach[U](f: A => U) = RB.foreachKey(tree, f) - override def rangeImpl(from: Option[A], until: Option[A]): TreeSet[A] = { - val ntree = RB.range(tree, from, true, until, false) - newSet(ntree) - } - override def to(to: A): TreeSet[A] = { - val ntree = RB.range(tree, None, true, Some(to), true) - newSet(ntree) - } + override def rangeImpl(from: Option[A], until: Option[A]): TreeSet[A] = newSet(RB.rangeImpl(tree, from, until)) + override def range(from: A, until: A): TreeSet[A] = newSet(RB.range(tree, from, until)) + override def from(from: A): TreeSet[A] = newSet(RB.from(tree, from)) + override def to(to: A): TreeSet[A] = newSet(RB.to(tree, to)) + override def until(until: A): TreeSet[A] = newSet(RB.until(tree, until)) + override def firstKey = head override def lastKey = last } diff --git a/test/files/scalacheck/redblacktree.scala b/test/files/scalacheck/redblacktree.scala index 14538c2352..e4b356c889 100644 --- a/test/files/scalacheck/redblacktree.scala +++ b/test/files/scalacheck/redblacktree.scala @@ -174,39 +174,33 @@ package scala.collection.immutable.redblacktree { object TestRange extends RedBlackTreeTest with RedBlackTreeInvariants { import RB._ - override type ModifyParm = (Option[Int], Boolean, Option[Int], Boolean) + override type ModifyParm = (Option[Int], Option[Int]) override def genParm(tree: Tree[String, Int]): Gen[ModifyParm] = for { from <- choose(0, iterator(tree).size) - fromInclusive <- oneOf(false, true) to <- choose(0, iterator(tree).size) suchThat (from <=) - toInclusive <- oneOf(false, true) optionalFrom <- oneOf(Some(from), None, Some(from)) // Double Some(n) to get around a bug optionalTo <- oneOf(Some(to), None, Some(to)) // Double Some(n) to get around a bug - } yield (optionalFrom, fromInclusive, optionalTo, toInclusive) + } yield (optionalFrom, optionalTo) override def modify(tree: Tree[String, Int], parm: ModifyParm): Tree[String, Int] = { val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) - val to = parm._3 flatMap (nodeAt(tree, _) map (_._1)) - range(tree, from, parm._2, to, parm._4) + val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) + rangeImpl(tree, from, to) } property("range boundaries respected") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) - val fromPredicate: String => String => Boolean = if (parm._2) (_ <=) else (_ <) - val to = parm._3 flatMap (nodeAt(tree, _) map (_._1)) - val toPredicate: String => String => Boolean = if (parm._4) (_ >=) else (_ >) - ("lower boundary" |: (from forall ( key => keysIterator(newTree) forall fromPredicate(key)))) && - ("upper boundary" |: (to forall ( key => keysIterator(newTree) forall toPredicate(key)))) + val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) + ("lower boundary" |: (from forall ( key => keysIterator(newTree) forall (key <=)))) && + ("upper boundary" |: (to forall ( key => keysIterator(newTree) forall (key >)))) } property("range returns all elements") = forAll(genInput) { case (tree, parm, newTree) => val from = parm._1 flatMap (nodeAt(tree, _) map (_._1)) - val fromPredicate: String => String => Boolean = if (parm._2) (_ >=) else (_ >) - val to = parm._3 flatMap (nodeAt(tree, _) map (_._1)) - val toPredicate: String => String => Boolean = if (parm._4) (_ <=) else (_ <) + val to = parm._2 flatMap (nodeAt(tree, _) map (_._1)) val filteredTree = (keysIterator(tree) - .filter(key => from forall fromPredicate(key)) - .filter(key => to forall toPredicate(key)) + .filter(key => from forall (key >=)) + .filter(key => to forall (key <)) .toList) filteredTree == keysIterator(newTree).toList } -- cgit v1.2.3 From 78374f340e71d8e8f71c5bcd11452b72c207068c Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sun, 22 Jan 2012 21:17:29 +0100 Subject: Custom implementations of drop/take/slice. This mainly helps performance when comparing keys is expensive. --- .../scala/collection/immutable/RedBlackTree.scala | 39 +++++++++++++++++++++- .../scala/collection/immutable/TreeMap.scala | 6 ++-- .../scala/collection/immutable/TreeSet.scala | 6 ++-- test/files/scalacheck/treemap.scala | 18 ++++++++-- test/files/scalacheck/treeset.scala | 18 ++++++++-- 5 files changed, 75 insertions(+), 12 deletions(-) (limited to 'test/files/scalacheck') diff --git a/src/library/scala/collection/immutable/RedBlackTree.scala b/src/library/scala/collection/immutable/RedBlackTree.scala index 7110ca4194..731a0f7975 100644 --- a/src/library/scala/collection/immutable/RedBlackTree.scala +++ b/src/library/scala/collection/immutable/RedBlackTree.scala @@ -56,6 +56,10 @@ object RedBlackTree { def to[A: Ordering, B](tree: Tree[A, B], to: A): Tree[A, B] = blacken(doTo(tree, to)) def until[A: Ordering, B](tree: Tree[A, B], key: A): Tree[A, B] = blacken(doUntil(tree, key)) + def drop[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = blacken(doDrop(tree, n)) + def take[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = blacken(doTake(tree, n)) + def slice[A: Ordering, B](tree: Tree[A, B], from: Int, until: Int): Tree[A, B] = blacken(doSlice(tree, from, until)) + def smallest[A, B](tree: Tree[A, B]): Tree[A, B] = { if (tree eq null) throw new NoSuchElementException("empty map") var result = tree @@ -86,7 +90,7 @@ object RedBlackTree { @tailrec def nth[A, B](tree: Tree[A, B], n: Int): Tree[A, B] = { - val count = RedBlackTree.count(tree.left) + val count = this.count(tree.left) if (n < count) nth(tree.left, n) else if (n > count) nth(tree.right, n - count - 1) else tree @@ -243,6 +247,39 @@ object RedBlackTree { else rebalance(tree, newLeft, newRight) } + private[this] def doDrop[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = { + if (n <= 0) return tree + if (n >= this.count(tree)) return null + val count = this.count(tree.left) + if (n > count) return doDrop(tree.right, n - count - 1) + val newLeft = doDrop(tree.left, n) + if (newLeft eq tree.left) tree + else if (newLeft eq null) upd(tree.right, tree.key, tree.value) + else rebalance(tree, newLeft, tree.right) + } + private[this] def doTake[A: Ordering, B](tree: Tree[A, B], n: Int): Tree[A, B] = { + if (n <= 0) return null + if (n >= this.count(tree)) return tree + val count = this.count(tree.left) + if (n <= count) return doTake(tree.left, n) + val newRight = doTake(tree.right, n - count - 1) + if (newRight eq tree.right) tree + else if (newRight eq null) upd(tree.left, tree.key, tree.value) + else rebalance(tree, tree.left, newRight) + } + private[this] def doSlice[A: Ordering, B](tree: Tree[A, B], from: Int, until: Int): Tree[A, B] = { + if (tree eq null) return null + val count = this.count(tree.left) + if (from > count) return doSlice(tree.right, from - count - 1, until - count - 1) + if (until <= count) return doSlice(tree.left, from, until) + val newLeft = doDrop(tree.left, from) + val newRight = doTake(tree.right, until - count - 1) + if ((newLeft eq tree.left) && (newRight eq tree.right)) tree + else if (newLeft eq null) upd(newRight, tree.key, tree.value) + else if (newRight eq null) upd(newLeft, tree.key, tree.value) + else rebalance(tree, newLeft, newRight) + } + // The zipper returned might have been traversed left-most (always the left child) // or right-most (always the right child). Left trees are traversed right-most, // and right trees are traversed leftmost. diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index a24221decc..dc4f79be35 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -89,20 +89,20 @@ class TreeMap[A, +B] private (tree: RB.Tree[A, B])(implicit val ordering: Orderi override def drop(n: Int) = { if (n <= 0) this else if (n >= size) empty - else from(RB.nth(tree, n).key) + else new TreeMap(RB.drop(tree, n)) } override def take(n: Int) = { if (n <= 0) empty else if (n >= size) this - else until(RB.nth(tree, n).key) + else new TreeMap(RB.take(tree, n)) } override def slice(from: Int, until: Int) = { if (until <= from) empty else if (from <= 0) take(until) else if (until >= size) drop(from) - else range(RB.nth(tree, from).key, RB.nth(tree, until).key) + else new TreeMap(RB.slice(tree, from, until)) } override def dropRight(n: Int) = take(size - n) diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index e21aec362c..1b3d72ceb7 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -66,20 +66,20 @@ class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: Orderin override def drop(n: Int) = { if (n <= 0) this else if (n >= size) empty - else from(RB.nth(tree, n).key) + else newSet(RB.drop(tree, n)) } override def take(n: Int) = { if (n <= 0) empty else if (n >= size) this - else until(RB.nth(tree, n).key) + else newSet(RB.take(tree, n)) } override def slice(from: Int, until: Int) = { if (until <= from) empty else if (from <= 0) take(until) else if (until >= size) drop(from) - else range(RB.nth(tree, from).key, RB.nth(tree, until).key) + else newSet(RB.slice(tree, from, until)) } override def dropRight(n: Int) = take(size - n) diff --git a/test/files/scalacheck/treemap.scala b/test/files/scalacheck/treemap.scala index ba6d117fd4..f672637c57 100644 --- a/test/files/scalacheck/treemap.scala +++ b/test/files/scalacheck/treemap.scala @@ -7,11 +7,12 @@ import util._ import Buildable._ object Test extends Properties("TreeMap") { - implicit def arbTreeMap[A : Arbitrary : Ordering, B : Arbitrary]: Arbitrary[TreeMap[A, B]] = - Arbitrary(for { + def genTreeMap[A: Arbitrary: Ordering, B: Arbitrary]: Gen[TreeMap[A, B]] = + for { keys <- listOf(arbitrary[A]) values <- listOfN(keys.size, arbitrary[B]) - } yield TreeMap(keys zip values: _*)) + } yield TreeMap(keys zip values: _*) + implicit def arbTreeMap[A : Arbitrary : Ordering, B : Arbitrary] = Arbitrary(genTreeMap[A, B]) property("foreach/iterator consistency") = forAll { (subject: TreeMap[Int, String]) => val it = subject.iterator @@ -96,6 +97,17 @@ object Test extends Properties("TreeMap") { prefix == subject.take(n) && suffix == subject.drop(n) } + def genSliceParms = for { + tree <- genTreeMap[Int, String] + from <- choose(0, tree.size) + until <- choose(from, tree.size) + } yield (tree, from, until) + + property("slice") = forAll(genSliceParms) { case (subject, from, until) => + val slice = subject.slice(from, until) + slice.size == until - from && subject.toSeq == subject.take(from).toSeq ++ slice ++ subject.drop(until) + } + property("takeWhile") = forAll { (subject: TreeMap[Int, String]) => val result = subject.takeWhile(_._1 < 0) result.forall(_._1 < 0) && result == subject.take(result.size) diff --git a/test/files/scalacheck/treeset.scala b/test/files/scalacheck/treeset.scala index e6d1b50860..98e38c8219 100644 --- a/test/files/scalacheck/treeset.scala +++ b/test/files/scalacheck/treeset.scala @@ -6,8 +6,11 @@ import Arbitrary._ import util._ object Test extends Properties("TreeSet") { - implicit def arbTreeSet[A : Arbitrary : Ordering]: Arbitrary[TreeSet[A]] = - Arbitrary(listOf(arbitrary[A]) map (elements => TreeSet(elements: _*))) + def genTreeSet[A: Arbitrary: Ordering]: Gen[TreeSet[A]] = + for { + elements <- listOf(arbitrary[A]) + } yield TreeSet(elements: _*) + implicit def arbTreeSet[A : Arbitrary : Ordering]: Arbitrary[TreeSet[A]] = Arbitrary(genTreeSet) property("foreach/iterator consistency") = forAll { (subject: TreeSet[Int]) => val it = subject.iterator @@ -92,6 +95,17 @@ object Test extends Properties("TreeSet") { prefix == subject.take(n) && suffix == subject.drop(n) } + def genSliceParms = for { + tree <- genTreeSet[Int] + from <- choose(0, tree.size) + until <- choose(from, tree.size) + } yield (tree, from, until) + + property("slice") = forAll(genSliceParms) { case (subject, from, until) => + val slice = subject.slice(from, until) + slice.size == until - from && subject.toSeq == subject.take(from).toSeq ++ slice ++ subject.drop(until) + } + property("takeWhile") = forAll { (subject: TreeSet[Int]) => val result = subject.takeWhile(_ < 0) result.forall(_ < 0) && result == subject.take(result.size) -- cgit v1.2.3