From fe6886eb0ec9c02fa666e9e7af09bab92b985d05 Mon Sep 17 00:00:00 2001 From: Rui Gonçalves Date: Sun, 17 Apr 2016 17:51:17 +0100 Subject: Improve performance and behavior of ListMap and ListSet Makes the immutable `ListMap` and `ListSet` collections more alike one another, both in their semantics and in their performance. In terms of semantics, makes the `ListSet` iterator return the elements in their insertion order, as `ListMap` already does. While, as mentioned in SI-8985, `ListMap` and `ListSet` doesn't seem to make any guarantees in terms of iteration order, I believe users expect `ListSet` and `ListMap` to behave in the same way, particularly when they are implemented in the exact same way. In terms of performance, `ListSet` has a custom builder that avoids creation in O(N^2) time. However, this significantly reduces its performance in the creation of small sets, as its requires the instantiation and usage of an auxilliary HashSet. As `ListMap` and `ListSet` are only suitable for small sizes do to their performance characteristics, the builder is removed, the default `SetBuilder` being used instead. --- .../scala/collection/immutable/ListMap.scala | 34 +++++------ .../scala/collection/immutable/ListSet.scala | 67 +++++++--------------- test/files/jvm/serialization-new.check | 4 +- test/files/jvm/serialization.check | 4 +- test/files/run/t3822.scala | 19 ------ test/files/run/t6198.scala | 7 --- test/files/run/t7445.scala | 6 -- .../scala/collection/immutable/ListMapTest.scala | 48 ++++++++++++++++ .../scala/collection/immutable/ListSetTest.scala | 53 +++++++++++++++++ 9 files changed, 144 insertions(+), 98 deletions(-) delete mode 100644 test/files/run/t3822.scala delete mode 100644 test/files/run/t7445.scala create mode 100644 test/junit/scala/collection/immutable/ListMapTest.scala create mode 100644 test/junit/scala/collection/immutable/ListSetTest.scala diff --git a/src/library/scala/collection/immutable/ListMap.scala b/src/library/scala/collection/immutable/ListMap.scala index e1bcc0711c..9af05183dd 100644 --- a/src/library/scala/collection/immutable/ListMap.scala +++ b/src/library/scala/collection/immutable/ListMap.scala @@ -113,7 +113,8 @@ extends AbstractMap[A, B] * @param xs the traversable object. */ override def ++[B1 >: B](xs: GenTraversableOnce[(A, B1)]): ListMap[A, B1] = - ((repr: ListMap[A, B1]) /: xs.seq) (_ + _) + if (xs.isEmpty) this + else ((repr: ListMap[A, B1]) /: xs) (_ + _) /** This creates a new mapping without the given `key`. * If the map does not contain a mapping for the given key, the @@ -125,14 +126,18 @@ extends AbstractMap[A, B] /** Returns an iterator over key-value pairs. */ - def iterator: Iterator[(A,B)] = - new AbstractIterator[(A,B)] { - var self: ListMap[A,B] = ListMap.this - def hasNext = !self.isEmpty - def next(): (A,B) = - if (!hasNext) throw new NoSuchElementException("next on empty iterator") - else { val res = (self.key, self.value); self = self.next; res } - }.toList.reverseIterator + def iterator: Iterator[(A, B)] = { + def reverseList = { + var curr: ListMap[A, B] = this + var res: List[(A, B)] = Nil + while (!curr.isEmpty) { + res = (curr.key, curr.value) :: res + curr = curr.next + } + res + } + reverseList.iterator + } protected def key: A = throw new NoSuchElementException("empty map") protected def value: B = throw new NoSuchElementException("empty map") @@ -210,14 +215,9 @@ extends AbstractMap[A, B] override def - (k: A): ListMap[A, B1] = remove0(k, this, Nil) @tailrec private def remove0(k: A, cur: ListMap[A, B1], acc: List[ListMap[A, B1]]): ListMap[A, B1] = - if (cur.isEmpty) - acc.last - else if (k == cur.key) - (cur.next /: acc) { - case (t, h) => val tt = t; new tt.Node(h.key, h.value) // SI-7459 - } - else - remove0(k, cur.next, cur::acc) + if (cur.isEmpty) acc.last + else if (k == cur.key) (cur.next /: acc) { case (t, h) => new t.Node(h.key, h.value) } + else remove0(k, cur.next, cur::acc) override protected def next: ListMap[A, B1] = ListMap.this diff --git a/src/library/scala/collection/immutable/ListSet.scala b/src/library/scala/collection/immutable/ListSet.scala index d20e7bc6d2..7803e055ed 100644 --- a/src/library/scala/collection/immutable/ListSet.scala +++ b/src/library/scala/collection/immutable/ListSet.scala @@ -12,7 +12,6 @@ package immutable import generic._ import scala.annotation.tailrec -import mutable.{Builder, ReusableBuilder} /** $factoryInfo * @define Coll immutable.ListSet @@ -23,33 +22,8 @@ object ListSet extends ImmutableSetFactory[ListSet] { /** setCanBuildFromInfo */ implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, ListSet[A]] = setCanBuildFrom[A] - override def newBuilder[A]: Builder[A, ListSet[A]] = new ListSetBuilder[A] - private object EmptyListSet extends ListSet[Any] { } private[collection] def emptyInstance: ListSet[Any] = EmptyListSet - - /** A custom builder because forgetfully adding elements one at - * a time to a list backed set puts the "squared" in N^2. There is a - * temporary space cost, but it's improbable a list backed set could - * become large enough for this to matter given its pricy element lookup. - * - * This builder is reusable. - */ - class ListSetBuilder[Elem](initial: ListSet[Elem]) extends ReusableBuilder[Elem, ListSet[Elem]] { - def this() = this(empty[Elem]) - protected val elems = (new mutable.ListBuffer[Elem] ++= initial).reverse - protected val seen = new mutable.HashSet[Elem] ++= initial - - def +=(x: Elem): this.type = { - if (!seen(x)) { - elems += x - seen += x - } - this - } - def clear() = { elems.clear() ; seen.clear() } - def result() = elems.foldLeft(empty[Elem])(_ unchecked_+ _) - } } /** This class implements immutable sets using a list-based data @@ -104,9 +78,8 @@ sealed class ListSet[A] extends AbstractSet[A] */ override def ++(xs: GenTraversableOnce[A]): ListSet[A] = if (xs.isEmpty) this - else (new ListSet.ListSetBuilder(this) ++= xs.seq).result() + else (repr /: xs) (_ + _) - private[ListSet] def unchecked_+(e: A): ListSet[A] = new Node(e) private[ListSet] def unchecked_outer: ListSet[A] = throw new NoSuchElementException("Empty ListSet has no outer pointer") @@ -115,33 +88,34 @@ sealed class ListSet[A] extends AbstractSet[A] * @throws java.util.NoSuchElementException * @return the new iterator */ - def iterator: Iterator[A] = new AbstractIterator[A] { - var that: ListSet[A] = self - def hasNext = that.nonEmpty - def next: A = - if (hasNext) { - val res = that.head - that = that.tail - res + def iterator: Iterator[A] = { + def reverseList = { + var curr: ListSet[A] = self + var res: List[A] = Nil + while (!curr.isEmpty) { + res = curr.elem :: res + curr = curr.next } - else Iterator.empty.next() + res + } + reverseList.iterator } /** * @throws java.util.NoSuchElementException */ - override def head: A = throw new NoSuchElementException("Set has no elements") + protected def elem: A = throw new NoSuchElementException("elem of empty set") /** * @throws java.util.NoSuchElementException */ - override def tail: ListSet[A] = throw new NoSuchElementException("Next of an empty set") + protected def next: ListSet[A] = throw new NoSuchElementException("Next of an empty set") override def stringPrefix = "ListSet" /** Represents an entry in the `ListSet`. */ - protected class Node(override val head: A) extends ListSet[A] with Serializable { + protected class Node(override val elem: A) extends ListSet[A] with Serializable { override private[ListSet] def unchecked_outer = self /** Returns the number of elements in this set. @@ -166,7 +140,7 @@ sealed class ListSet[A] extends AbstractSet[A] */ override def contains(e: A) = containsInternal(this, e) @tailrec private def containsInternal(n: ListSet[A], e: A): Boolean = - !n.isEmpty && (n.head == e || containsInternal(n.unchecked_outer, e)) + !n.isEmpty && (n.elem == e || containsInternal(n.unchecked_outer, e)) /** This method creates a new set with an additional element. */ @@ -174,11 +148,14 @@ sealed class ListSet[A] extends AbstractSet[A] /** `-` can be used to remove a single element from a set. */ - override def -(e: A): ListSet[A] = if (e == head) self else { - val tail = self - e; new tail.Node(head) - } + override def -(e: A): ListSet[A] = removeInternal(e, this, Nil) + + @tailrec private def removeInternal(k: A, cur: ListSet[A], acc: List[ListSet[A]]): ListSet[A] = + if (cur.isEmpty) acc.last + else if (k == cur.elem) (cur.next /: acc) { case (t, h) => new t.Node(h.elem) } + else removeInternal(k, cur.next, cur :: acc) - override def tail: ListSet[A] = self + override protected def next: ListSet[A] = self } override def toSet[B >: A]: Set[B] = this.asInstanceOf[ListSet[B]] diff --git a/test/files/jvm/serialization-new.check b/test/files/jvm/serialization-new.check index cb26446f40..91248320d4 100644 --- a/test/files/jvm/serialization-new.check +++ b/test/files/jvm/serialization-new.check @@ -89,8 +89,8 @@ x = Map(buffers -> 20, layers -> 2, title -> 3) y = Map(buffers -> 20, layers -> 2, title -> 3) x equals y: true, y equals x: true -x = ListSet(5, 3) -y = ListSet(5, 3) +x = ListSet(3, 5) +y = ListSet(3, 5) x equals y: true, y equals x: true x = Queue(a, b, c) diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check index cb26446f40..91248320d4 100644 --- a/test/files/jvm/serialization.check +++ b/test/files/jvm/serialization.check @@ -89,8 +89,8 @@ x = Map(buffers -> 20, layers -> 2, title -> 3) y = Map(buffers -> 20, layers -> 2, title -> 3) x equals y: true, y equals x: true -x = ListSet(5, 3) -y = ListSet(5, 3) +x = ListSet(3, 5) +y = ListSet(3, 5) x equals y: true, y equals x: true x = Queue(a, b, c) diff --git a/test/files/run/t3822.scala b/test/files/run/t3822.scala deleted file mode 100644 index c35804035e..0000000000 --- a/test/files/run/t3822.scala +++ /dev/null @@ -1,19 +0,0 @@ -import scala.collection.{ mutable, immutable, generic } -import immutable.ListSet - -object Test { - def main(args: Array[String]): Unit = { - val xs = ListSet(-100000 to 100001: _*) - - assert(xs.size == 200002) - assert(xs.sum == 100001) - - val ys = ListSet[Int]() - val ys1 = (1 to 12).grouped(3).foldLeft(ys)(_ ++ _) - val ys2 = (1 to 12).foldLeft(ys)(_ + _) - - assert(ys1 == ys2) - } -} - - diff --git a/test/files/run/t6198.scala b/test/files/run/t6198.scala index 5aa8f1c1cf..65dbaf8160 100644 --- a/test/files/run/t6198.scala +++ b/test/files/run/t6198.scala @@ -1,13 +1,6 @@ import scala.collection.immutable._ object Test extends App { - // test that ListSet.tail does not use a builder - // we can't test for O(1) behavior, so the best we can do is to - // check that ls.tail always returns the same instance - val ls = ListSet.empty[Int] + 1 + 2 - - if(ls.tail ne ls.tail) - println("ListSet.tail should not use a builder!") // class that always causes hash collisions case class Collision(value:Int) { override def hashCode = 0 } diff --git a/test/files/run/t7445.scala b/test/files/run/t7445.scala deleted file mode 100644 index e4ffeb8e1a..0000000000 --- a/test/files/run/t7445.scala +++ /dev/null @@ -1,6 +0,0 @@ -import scala.collection.immutable.ListMap - -object Test extends App { - val a = ListMap(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5); - require(a.tail == ListMap(2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5)); -} diff --git a/test/junit/scala/collection/immutable/ListMapTest.scala b/test/junit/scala/collection/immutable/ListMapTest.scala new file mode 100644 index 0000000000..320a976755 --- /dev/null +++ b/test/junit/scala/collection/immutable/ListMapTest.scala @@ -0,0 +1,48 @@ +package scala.collection.immutable + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +@RunWith(classOf[JUnit4]) +class ListMapTest { + + @Test + def t7445(): Unit = { + val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5) + assertEquals(ListMap(2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5), m.tail) + } + + @Test + def hasCorrectBuilder(): Unit = { + val m = ListMap("a" -> "1", "b" -> "2", "c" -> "3", "b" -> "2.2", "d" -> "4") + assertEquals(List("a" -> "1", "c" -> "3", "b" -> "2.2", "d" -> "4"), m.toList) + } + + @Test + def hasCorrectHeadTailLastInit(): Unit = { + val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3) + assertEquals(1 -> 1, m.head) + assertEquals(ListMap(2 -> 2, 3 -> 3), m.tail) + assertEquals(3 -> 3, m.last) + assertEquals(ListMap(1 -> 1, 2 -> 2), m.init) + } + + @Test + def hasCorrectAddRemove(): Unit = { + val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3) + assertEquals(ListMap(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), m + (4 -> 4)) + assertEquals(ListMap(1 -> 1, 3 -> 3, 2 -> 4), m + (2 -> 4)) + assertEquals(ListMap(1 -> 1, 2 -> 2, 3 -> 3), m + (2 -> 2)) + assertEquals(ListMap(2 -> 2, 3 -> 3), m - 1) + assertEquals(ListMap(1 -> 1, 3 -> 3), m - 2) + assertEquals(ListMap(1 -> 1, 2 -> 2, 3 -> 3), m - 4) + } + + @Test + def hasCorrectIterator(): Unit = { + val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3, 5 -> 5, 4 -> 4) + assertEquals(List(1 -> 1, 2 -> 2, 3 -> 3, 5 -> 5, 4 -> 4), m.iterator.toList) + } +} diff --git a/test/junit/scala/collection/immutable/ListSetTest.scala b/test/junit/scala/collection/immutable/ListSetTest.scala new file mode 100644 index 0000000000..395da88c75 --- /dev/null +++ b/test/junit/scala/collection/immutable/ListSetTest.scala @@ -0,0 +1,53 @@ +package scala.collection.immutable + +import org.junit.Assert._ +import org.junit.Test +import org.junit.runner.RunWith +import org.junit.runners.JUnit4 + +@RunWith(classOf[JUnit4]) +class ListSetTest { + + @Test + def t7445(): Unit = { + val s = ListSet(1, 2, 3, 4, 5) + assertEquals(ListSet(2, 3, 4, 5), s.tail) + } + + @Test + def hasCorrectBuilder(): Unit = { + val m = ListSet("a", "b", "c", "b", "d") + assertEquals(List("a", "b", "c", "d"), m.toList) + } + + @Test + def hasTailRecursiveDelete(): Unit = { + val s = ListSet(1 to 50000: _*) + try s - 25000 catch { case e: StackOverflowError => fail("A stack overflow occurred") } + } + + @Test + def hasCorrectHeadTailLastInit(): Unit = { + val m = ListSet(1, 2, 3) + assertEquals(1, m.head) + assertEquals(ListSet(2, 3), m.tail) + assertEquals(3, m.last) + assertEquals(ListSet(1, 2), m.init) + } + + @Test + def hasCorrectAddRemove(): Unit = { + val m = ListSet(1, 2, 3) + assertEquals(ListSet(1, 2, 3, 4), m + 4) + assertEquals(ListSet(1, 2, 3), m + 2) + assertEquals(ListSet(2, 3), m - 1) + assertEquals(ListSet(1, 3), m - 2) + assertEquals(ListSet(1, 2, 3), m - 4) + } + + @Test + def hasCorrectIterator(): Unit = { + val s = ListSet(1, 2, 3, 5, 4) + assertEquals(List(1, 2, 3, 5, 4), s.iterator.toList) + } +} -- cgit v1.2.3