From 95cb7bc7e3017a0004a61749c7d121371c4fe31b Mon Sep 17 00:00:00 2001 From: Erik Rozendaal Date: Sun, 18 Dec 2011 20:41:23 +0100 Subject: Implemented drop/take/slice/splitAt/dropRight/takeRight for TreeMap/TreeSet by splitting the underlying RedBlack tree. This makes the operation O(log n) instead of O(n) and allows more structural sharing. --- .../scala/collection/immutable/RedBlack.scala | 7 +++++++ .../scala/collection/immutable/TreeMap.scala | 23 ++++++++++++++++++++++ .../scala/collection/immutable/TreeSet.scala | 23 ++++++++++++++++++++++ 3 files changed, 53 insertions(+) (limited to 'src') diff --git a/src/library/scala/collection/immutable/RedBlack.scala b/src/library/scala/collection/immutable/RedBlack.scala index 534d476507..5ce2a29dc2 100644 --- a/src/library/scala/collection/immutable/RedBlack.scala +++ b/src/library/scala/collection/immutable/RedBlack.scala @@ -45,6 +45,7 @@ abstract class RedBlack[A] extends Serializable { def first : A def last : A def count : Int + protected[immutable] def nth(n: Int): NonEmpty[B] } abstract class NonEmpty[+B] extends Tree[B] with Serializable { def isEmpty = false @@ -256,6 +257,11 @@ abstract class RedBlack[A] extends Serializable { def first = if (left .isEmpty) key else left.first def last = if (right.isEmpty) key else right.last val count = 1 + left.count + right.count + protected[immutable] def nth(n: Int) = { + if (n < left.count) left.nth(n) + else if (n > left.count) right.nth(n - left.count - 1) + else this + } } case object Empty extends Tree[Nothing] { def isEmpty = true @@ -274,6 +280,7 @@ abstract class RedBlack[A] extends Serializable { def first = throw new NoSuchElementException("empty map") def last = throw new NoSuchElementException("empty map") def count = 0 + protected[immutable] def nth(n: Int) = throw new NoSuchElementException("empty map") } case class RedTree[+B](override val key: A, override val value: B, diff --git a/src/library/scala/collection/immutable/TreeMap.scala b/src/library/scala/collection/immutable/TreeMap.scala index 0e160ca50e..bc91bbe268 100644 --- a/src/library/scala/collection/immutable/TreeMap.scala +++ b/src/library/scala/collection/immutable/TreeMap.scala @@ -85,6 +85,29 @@ class TreeMap[A, +B](override val size: Int, t: RedBlack[A]#Tree[B])(implicit va override def tail = new TreeMap(size - 1, tree.delete(firstKey)) override def init = new TreeMap(size - 1, tree.delete(lastKey)) + override def drop(n: Int) = { + if (n <= 0) this + else if (n >= size) empty + else from(tree.nth(n).key) + } + + override def take(n: Int) = { + if (n <= 0) empty + else if (n >= size) this + else until(tree.nth(n).key) + } + + override def slice(from: Int, until: Int) = { + if (until <= from) empty + else if (from <= 0) take(until) + else if (until >= size) drop(from) + else range(tree.nth(from).key, tree.nth(until).key) + } + + override def dropRight(n: Int) = take(size - n) + override def takeRight(n: Int) = drop(size - n) + override def splitAt(n: Int) = (take(n), drop(n)) + /** A factory to create empty maps of the same type of keys. */ override def empty: TreeMap[A, B] = TreeMap.empty[A, B](ordering) diff --git a/src/library/scala/collection/immutable/TreeSet.scala b/src/library/scala/collection/immutable/TreeSet.scala index b969ecc0e8..dfaffcd581 100644 --- a/src/library/scala/collection/immutable/TreeSet.scala +++ b/src/library/scala/collection/immutable/TreeSet.scala @@ -61,6 +61,29 @@ class TreeSet[A](override val size: Int, t: RedBlack[A]#Tree[Unit]) override def tail = new TreeSet(size - 1, tree.delete(firstKey)) override def init = new TreeSet(size - 1, tree.delete(lastKey)) + override def drop(n: Int) = { + if (n <= 0) this + else if (n >= size) empty + else from(tree.nth(n).key) + } + + override def take(n: Int) = { + if (n <= 0) empty + else if (n >= size) this + else until(tree.nth(n).key) + } + + override def slice(from: Int, until: Int) = { + if (until <= from) empty + else if (from <= 0) take(until) + else if (until >= size) drop(from) + else range(tree.nth(from).key, tree.nth(until).key) + } + + override def dropRight(n: Int) = take(size - n) + override def takeRight(n: Int) = drop(size - n) + override def splitAt(n: Int) = (take(n), drop(n)) + def isSmaller(x: A, y: A) = compare(x,y) < 0 def this()(implicit ordering: Ordering[A]) = this(0, null)(ordering) -- cgit v1.2.3