From a01f074d3e3229db905e21e78469ab02dc6ca0cf Mon Sep 17 00:00:00 2001 From: Aleksandar Pokopec Date: Fri, 19 Nov 2010 12:38:11 +0000 Subject: Applying patch by Daniel Sobral for #3796. No review. --- .../scala/collection/immutable/RedBlack.scala | 144 ++++++++++++++++----- 1 file changed, 111 insertions(+), 33 deletions(-) (limited to 'src/library') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 8903e5a125..ae72c3ae69 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -34,6 +34,7 @@ abstract class RedBlack[A] { def lookup(x: A): Tree[B] def update[B1 >: B](k: A, v: B1): Tree[B1] = blacken(upd(k, v)) def delete(k: A): Tree[B] = blacken(del(k)) + def range(from: Option[A], until: Option[A]): Tree[B] = blacken(rng(from, until)) def foreach[U](f: (A, B) => U) @deprecated("use `foreach' instead") def visit[T](input: T)(f: (T, A, B) => (Boolean, T)): (Boolean, T) @@ -43,7 +44,7 @@ abstract class RedBlack[A] { def upd[B1 >: B](k: A, v: B1): Tree[B1] def del(k: A): Tree[B] def smallest: NonEmpty[B] - def range(from: Option[A], until: Option[A]): Tree[B] + def rng(from: Option[A], until: Option[A]): Tree[B] def first : A def last : A def count : Int @@ -59,23 +60,23 @@ abstract class RedBlack[A] { if (isSmaller(k, key)) left.lookup(k) else if (isSmaller(key, k)) right.lookup(k) else this + private[this] def balanceLeft[B1 >: B](isBlack: Boolean, z: A, zv: B, l: Tree[B1], d: Tree[B1])/*: NonEmpty[B1]*/ = l match { + case RedTree(y, yv, RedTree(x, xv, a, b), c) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case RedTree(x, xv, a, RedTree(y, yv, b, c)) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case _ => + mkTree(isBlack, z, zv, l, d) + } + private[this] def balanceRight[B1 >: B](isBlack: Boolean, x: A, xv: B, a: Tree[B1], r: Tree[B1])/*: NonEmpty[B1]*/ = r match { + case RedTree(z, zv, RedTree(y, yv, b, c), d) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case RedTree(y, yv, b, RedTree(z, zv, c, d)) => + RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) + case _ => + mkTree(isBlack, x, xv, a, r) + } def upd[B1 >: B](k: A, v: B1): Tree[B1] = { - def balanceLeft(isBlack: Boolean, z: A, zv: B, l: Tree[B1], d: Tree[B1]) = l match { - case RedTree(y, yv, RedTree(x, xv, a, b), c) => - RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) - case RedTree(x, xv, a, RedTree(y, yv, b, c)) => - RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) - case _ => - mkTree(isBlack, z, zv, l, d) - } - def balanceRight(isBlack: Boolean, x: A, xv: B, a: Tree[B1], r: Tree[B1]) = r match { - case RedTree(z, zv, RedTree(y, yv, b, c), d) => - RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) - case RedTree(y, yv, b, RedTree(z, zv, c, d)) => - RedTree(y, yv, BlackTree(x, xv, a, b), BlackTree(z, zv, c, d)) - case _ => - mkTree(isBlack, x, xv, a, r) - } if (isSmaller(k, key)) balanceLeft(isBlack, key, value, left.upd(k, v), right) else if (isSmaller(key, k)) balanceRight(isBlack, key, value, left, right.upd(k, v)) else mkTree(isBlack, k, v, left, right) @@ -173,21 +174,97 @@ abstract class RedBlack[A] { if (!middle._1) return middle return this.right.visit(middle._2)(f) } - override def range(from: Option[A], until: Option[A]): Tree[B] = { - // if (from == None && until == None) return this - // if (from != None && isSmaller(key, from.get)) return right.range(from, until); - // if (until != None && (isSmaller(until.get,key) || !isSmaller(key,until.get))) - // return left.range(from, until); - // val newLeft = left.range(from, None) - // val newRight = right.range(None, until) - // if ((newLeft eq left) && (newRight eq right)) this - // else if (newLeft eq Empty) newRight.upd(key, value); - // else if (newRight eq Empty) newLeft.upd(key, value); - // else mkTree(isBlack, key, value, newLeft, newRight) - iterator - .dropWhile { case (key, _) => from exists (isSmaller(key, _)) } - .takeWhile { case (key, _) => until forall (isSmaller(_, key)) } - .foldLeft(Empty: Tree[B]) { case (tree, (key, value)) => tree.update(key, value) } + override def rng(from: Option[A], until: Option[A]): Tree[B] = { + if (from == None && until == None) return this + if (from != None && isSmaller(key, from.get)) return right.rng(from, until); + if (until != None && (isSmaller(until.get,key) || !isSmaller(key,until.get))) + return left.rng(from, until); + val newLeft = left.rng(from, None) + val newRight = right.rng(None, until) + if ((newLeft eq left) && (newRight eq right)) this + else if (newLeft eq Empty) newRight.upd(key, value); + else if (newRight eq Empty) newLeft.upd(key, value); + else rebalance(newLeft, newRight) + } + + // The zipper returned might have been traversed left-most (always the left child) + // or right-most (always the right child). Left trees are traversed right-most, + // and right trees are traversed leftmost. + + // Returns the zipper for the side with deepest black nodes depth, a flag + // indicating whether the trees were unbalanced at all, and a flag indicating + // whether the zipper was traversed left-most or right-most. + + // If the trees were balanced, returns an empty zipper + private[this] def compareDepth(left: Tree[B], right: Tree[B]): (List[NonEmpty[B]], Boolean, Boolean, Int) = { + // Once a side is found to be deeper, unzip it to the bottom + def unzip(zipper: List[NonEmpty[B]], leftMost: Boolean): List[NonEmpty[B]] = { + val next = if (leftMost) zipper.head.left else zipper.head.right + next match { + case node: NonEmpty[B] => unzip(node :: zipper, leftMost) + case Empty => zipper + } + } + + // Unzip left tree on the rightmost side and right tree on the leftmost side until one is + // found to be deeper, or the bottom is reached + def unzipBoth(left: Tree[B], + right: Tree[B], + leftZipper: List[NonEmpty[B]], + rightZipper: List[NonEmpty[B]], + smallerDepth: Int): (List[NonEmpty[B]], Boolean, Boolean, Int) = (left, right) match { + case (l: BlackTree[B], r: BlackTree[B]) => + unzipBoth(l.right, r.left, l :: leftZipper, r :: rightZipper, smallerDepth + 1) + case (l: RedTree[B], r: RedTree[B]) => + unzipBoth(l.right, r.left, l :: leftZipper, r :: rightZipper, smallerDepth) + case (_, r: RedTree[B]) => + unzipBoth(left, r.left, leftZipper, r :: rightZipper, smallerDepth) + case (l: RedTree[B], _) => + unzipBoth(l.right, right, l :: leftZipper, rightZipper, smallerDepth) + case (Empty, Empty) => + (Nil, true, false, smallerDepth) + case (Empty, r: BlackTree[B]) => + val leftMost = true + (unzip(r :: rightZipper, leftMost), false, leftMost, smallerDepth) + case (l: BlackTree[B], Empty) => + val leftMost = false + (unzip(l :: leftZipper, leftMost), false, leftMost, smallerDepth) + } + unzipBoth(left, right, Nil, Nil, 0) + } + + private[this] def rebalance(newLeft: Tree[B], newRight: Tree[B]) = { + // This is like drop(n-1), but only counting black nodes + def findDepth(zipper: List[NonEmpty[B]], depth: Int): List[NonEmpty[B]] = zipper match { + case (_: BlackTree[B]) :: tail => + if (depth == 1) zipper else findDepth(tail, depth - 1) + case _ :: tail => findDepth(tail, depth) + case Nil => error("Defect: unexpected empty zipper while computing range") + } + + // Blackening the smaller tree avoids balancing problems on union; + // this can't be done later, though, or it would change the result of compareDepth + val blkNewLeft = blacken(newLeft) + val blkNewRight = blacken(newRight) + val (zipper, levelled, leftMost, smallerDepth) = compareDepth(blkNewLeft, blkNewRight) + + if (levelled) { + BlackTree(key, value, blkNewLeft, blkNewRight) + } else { + val zipFrom = findDepth(zipper, smallerDepth) + val union = if (leftMost) { + RedTree(key, value, blkNewLeft, zipFrom.head) + } else { + RedTree(key, value, zipFrom.head, blkNewRight) + } + val zippedTree = zipFrom.tail.foldLeft(union: Tree[B]) { (tree, node) => + if (leftMost) + balanceLeft(node.isBlack, node.key, node.value, tree, node.right) + else + balanceRight(node.isBlack, node.key, node.value, node.left, tree) + } + zippedTree + } } def first = if (left .isEmpty) key else left.first def last = if (right.isEmpty) key else right.last @@ -209,7 +286,7 @@ abstract class RedBlack[A] { @deprecated("use `foreach' instead") def visit[T](input: T)(f: (T, A, Nothing) => (Boolean, T)) = (true, input) - def range(from: Option[A], until: Option[A]) = this + def rng(from: Option[A], until: Option[A]) = this def first = throw new NoSuchElementException("empty map") def last = throw new NoSuchElementException("empty map") def count = 0 @@ -230,3 +307,4 @@ abstract class RedBlack[A] { } } + -- cgit v1.2.3