summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRui Gonçalves <ruippeixotog@gmail.com>2016-04-17 17:51:17 +0100
committerRui Gonçalves <ruippeixotog@gmail.com>2016-05-17 10:55:16 +0100
commitfe6886eb0ec9c02fa666e9e7af09bab92b985d05 (patch)
treeaadd0a52ec583ff7fbd45ec95e611781adeee291
parent4c4c5e61a3b24e44247380eaf0519ee46036431a (diff)
downloadscala-fe6886eb0ec9c02fa666e9e7af09bab92b985d05.tar.gz
scala-fe6886eb0ec9c02fa666e9e7af09bab92b985d05.tar.bz2
scala-fe6886eb0ec9c02fa666e9e7af09bab92b985d05.zip
Improve performance and behavior of ListMap and ListSet
Makes the immutable `ListMap` and `ListSet` collections more alike one another, both in their semantics and in their performance. In terms of semantics, makes the `ListSet` iterator return the elements in their insertion order, as `ListMap` already does. While, as mentioned in SI-8985, `ListMap` and `ListSet` doesn't seem to make any guarantees in terms of iteration order, I believe users expect `ListSet` and `ListMap` to behave in the same way, particularly when they are implemented in the exact same way. In terms of performance, `ListSet` has a custom builder that avoids creation in O(N^2) time. However, this significantly reduces its performance in the creation of small sets, as its requires the instantiation and usage of an auxilliary HashSet. As `ListMap` and `ListSet` are only suitable for small sizes do to their performance characteristics, the builder is removed, the default `SetBuilder` being used instead.
-rw-r--r--src/library/scala/collection/immutable/ListMap.scala34
-rw-r--r--src/library/scala/collection/immutable/ListSet.scala67
-rw-r--r--test/files/jvm/serialization-new.check4
-rw-r--r--test/files/jvm/serialization.check4
-rw-r--r--test/files/run/t3822.scala19
-rw-r--r--test/files/run/t6198.scala7
-rw-r--r--test/files/run/t7445.scala6
-rw-r--r--test/junit/scala/collection/immutable/ListMapTest.scala48
-rw-r--r--test/junit/scala/collection/immutable/ListSetTest.scala53
9 files changed, 144 insertions, 98 deletions
diff --git a/src/library/scala/collection/immutable/ListMap.scala b/src/library/scala/collection/immutable/ListMap.scala
index e1bcc0711c..9af05183dd 100644
--- a/src/library/scala/collection/immutable/ListMap.scala
+++ b/src/library/scala/collection/immutable/ListMap.scala
@@ -113,7 +113,8 @@ extends AbstractMap[A, B]
* @param xs the traversable object.
*/
override def ++[B1 >: B](xs: GenTraversableOnce[(A, B1)]): ListMap[A, B1] =
- ((repr: ListMap[A, B1]) /: xs.seq) (_ + _)
+ if (xs.isEmpty) this
+ else ((repr: ListMap[A, B1]) /: xs) (_ + _)
/** This creates a new mapping without the given `key`.
* If the map does not contain a mapping for the given key, the
@@ -125,14 +126,18 @@ extends AbstractMap[A, B]
/** Returns an iterator over key-value pairs.
*/
- def iterator: Iterator[(A,B)] =
- new AbstractIterator[(A,B)] {
- var self: ListMap[A,B] = ListMap.this
- def hasNext = !self.isEmpty
- def next(): (A,B) =
- if (!hasNext) throw new NoSuchElementException("next on empty iterator")
- else { val res = (self.key, self.value); self = self.next; res }
- }.toList.reverseIterator
+ def iterator: Iterator[(A, B)] = {
+ def reverseList = {
+ var curr: ListMap[A, B] = this
+ var res: List[(A, B)] = Nil
+ while (!curr.isEmpty) {
+ res = (curr.key, curr.value) :: res
+ curr = curr.next
+ }
+ res
+ }
+ reverseList.iterator
+ }
protected def key: A = throw new NoSuchElementException("empty map")
protected def value: B = throw new NoSuchElementException("empty map")
@@ -210,14 +215,9 @@ extends AbstractMap[A, B]
override def - (k: A): ListMap[A, B1] = remove0(k, this, Nil)
@tailrec private def remove0(k: A, cur: ListMap[A, B1], acc: List[ListMap[A, B1]]): ListMap[A, B1] =
- if (cur.isEmpty)
- acc.last
- else if (k == cur.key)
- (cur.next /: acc) {
- case (t, h) => val tt = t; new tt.Node(h.key, h.value) // SI-7459
- }
- else
- remove0(k, cur.next, cur::acc)
+ if (cur.isEmpty) acc.last
+ else if (k == cur.key) (cur.next /: acc) { case (t, h) => new t.Node(h.key, h.value) }
+ else remove0(k, cur.next, cur::acc)
override protected def next: ListMap[A, B1] = ListMap.this
diff --git a/src/library/scala/collection/immutable/ListSet.scala b/src/library/scala/collection/immutable/ListSet.scala
index d20e7bc6d2..7803e055ed 100644
--- a/src/library/scala/collection/immutable/ListSet.scala
+++ b/src/library/scala/collection/immutable/ListSet.scala
@@ -12,7 +12,6 @@ package immutable
import generic._
import scala.annotation.tailrec
-import mutable.{Builder, ReusableBuilder}
/** $factoryInfo
* @define Coll immutable.ListSet
@@ -23,33 +22,8 @@ object ListSet extends ImmutableSetFactory[ListSet] {
/** setCanBuildFromInfo */
implicit def canBuildFrom[A]: CanBuildFrom[Coll, A, ListSet[A]] = setCanBuildFrom[A]
- override def newBuilder[A]: Builder[A, ListSet[A]] = new ListSetBuilder[A]
-
private object EmptyListSet extends ListSet[Any] { }
private[collection] def emptyInstance: ListSet[Any] = EmptyListSet
-
- /** A custom builder because forgetfully adding elements one at
- * a time to a list backed set puts the "squared" in N^2. There is a
- * temporary space cost, but it's improbable a list backed set could
- * become large enough for this to matter given its pricy element lookup.
- *
- * This builder is reusable.
- */
- class ListSetBuilder[Elem](initial: ListSet[Elem]) extends ReusableBuilder[Elem, ListSet[Elem]] {
- def this() = this(empty[Elem])
- protected val elems = (new mutable.ListBuffer[Elem] ++= initial).reverse
- protected val seen = new mutable.HashSet[Elem] ++= initial
-
- def +=(x: Elem): this.type = {
- if (!seen(x)) {
- elems += x
- seen += x
- }
- this
- }
- def clear() = { elems.clear() ; seen.clear() }
- def result() = elems.foldLeft(empty[Elem])(_ unchecked_+ _)
- }
}
/** This class implements immutable sets using a list-based data
@@ -104,9 +78,8 @@ sealed class ListSet[A] extends AbstractSet[A]
*/
override def ++(xs: GenTraversableOnce[A]): ListSet[A] =
if (xs.isEmpty) this
- else (new ListSet.ListSetBuilder(this) ++= xs.seq).result()
+ else (repr /: xs) (_ + _)
- private[ListSet] def unchecked_+(e: A): ListSet[A] = new Node(e)
private[ListSet] def unchecked_outer: ListSet[A] =
throw new NoSuchElementException("Empty ListSet has no outer pointer")
@@ -115,33 +88,34 @@ sealed class ListSet[A] extends AbstractSet[A]
* @throws java.util.NoSuchElementException
* @return the new iterator
*/
- def iterator: Iterator[A] = new AbstractIterator[A] {
- var that: ListSet[A] = self
- def hasNext = that.nonEmpty
- def next: A =
- if (hasNext) {
- val res = that.head
- that = that.tail
- res
+ def iterator: Iterator[A] = {
+ def reverseList = {
+ var curr: ListSet[A] = self
+ var res: List[A] = Nil
+ while (!curr.isEmpty) {
+ res = curr.elem :: res
+ curr = curr.next
}
- else Iterator.empty.next()
+ res
+ }
+ reverseList.iterator
}
/**
* @throws java.util.NoSuchElementException
*/
- override def head: A = throw new NoSuchElementException("Set has no elements")
+ protected def elem: A = throw new NoSuchElementException("elem of empty set")
/**
* @throws java.util.NoSuchElementException
*/
- override def tail: ListSet[A] = throw new NoSuchElementException("Next of an empty set")
+ protected def next: ListSet[A] = throw new NoSuchElementException("Next of an empty set")
override def stringPrefix = "ListSet"
/** Represents an entry in the `ListSet`.
*/
- protected class Node(override val head: A) extends ListSet[A] with Serializable {
+ protected class Node(override val elem: A) extends ListSet[A] with Serializable {
override private[ListSet] def unchecked_outer = self
/** Returns the number of elements in this set.
@@ -166,7 +140,7 @@ sealed class ListSet[A] extends AbstractSet[A]
*/
override def contains(e: A) = containsInternal(this, e)
@tailrec private def containsInternal(n: ListSet[A], e: A): Boolean =
- !n.isEmpty && (n.head == e || containsInternal(n.unchecked_outer, e))
+ !n.isEmpty && (n.elem == e || containsInternal(n.unchecked_outer, e))
/** This method creates a new set with an additional element.
*/
@@ -174,11 +148,14 @@ sealed class ListSet[A] extends AbstractSet[A]
/** `-` can be used to remove a single element from a set.
*/
- override def -(e: A): ListSet[A] = if (e == head) self else {
- val tail = self - e; new tail.Node(head)
- }
+ override def -(e: A): ListSet[A] = removeInternal(e, this, Nil)
+
+ @tailrec private def removeInternal(k: A, cur: ListSet[A], acc: List[ListSet[A]]): ListSet[A] =
+ if (cur.isEmpty) acc.last
+ else if (k == cur.elem) (cur.next /: acc) { case (t, h) => new t.Node(h.elem) }
+ else removeInternal(k, cur.next, cur :: acc)
- override def tail: ListSet[A] = self
+ override protected def next: ListSet[A] = self
}
override def toSet[B >: A]: Set[B] = this.asInstanceOf[ListSet[B]]
diff --git a/test/files/jvm/serialization-new.check b/test/files/jvm/serialization-new.check
index cb26446f40..91248320d4 100644
--- a/test/files/jvm/serialization-new.check
+++ b/test/files/jvm/serialization-new.check
@@ -89,8 +89,8 @@ x = Map(buffers -> 20, layers -> 2, title -> 3)
y = Map(buffers -> 20, layers -> 2, title -> 3)
x equals y: true, y equals x: true
-x = ListSet(5, 3)
-y = ListSet(5, 3)
+x = ListSet(3, 5)
+y = ListSet(3, 5)
x equals y: true, y equals x: true
x = Queue(a, b, c)
diff --git a/test/files/jvm/serialization.check b/test/files/jvm/serialization.check
index cb26446f40..91248320d4 100644
--- a/test/files/jvm/serialization.check
+++ b/test/files/jvm/serialization.check
@@ -89,8 +89,8 @@ x = Map(buffers -> 20, layers -> 2, title -> 3)
y = Map(buffers -> 20, layers -> 2, title -> 3)
x equals y: true, y equals x: true
-x = ListSet(5, 3)
-y = ListSet(5, 3)
+x = ListSet(3, 5)
+y = ListSet(3, 5)
x equals y: true, y equals x: true
x = Queue(a, b, c)
diff --git a/test/files/run/t3822.scala b/test/files/run/t3822.scala
deleted file mode 100644
index c35804035e..0000000000
--- a/test/files/run/t3822.scala
+++ /dev/null
@@ -1,19 +0,0 @@
-import scala.collection.{ mutable, immutable, generic }
-import immutable.ListSet
-
-object Test {
- def main(args: Array[String]): Unit = {
- val xs = ListSet(-100000 to 100001: _*)
-
- assert(xs.size == 200002)
- assert(xs.sum == 100001)
-
- val ys = ListSet[Int]()
- val ys1 = (1 to 12).grouped(3).foldLeft(ys)(_ ++ _)
- val ys2 = (1 to 12).foldLeft(ys)(_ + _)
-
- assert(ys1 == ys2)
- }
-}
-
-
diff --git a/test/files/run/t6198.scala b/test/files/run/t6198.scala
index 5aa8f1c1cf..65dbaf8160 100644
--- a/test/files/run/t6198.scala
+++ b/test/files/run/t6198.scala
@@ -1,13 +1,6 @@
import scala.collection.immutable._
object Test extends App {
- // test that ListSet.tail does not use a builder
- // we can't test for O(1) behavior, so the best we can do is to
- // check that ls.tail always returns the same instance
- val ls = ListSet.empty[Int] + 1 + 2
-
- if(ls.tail ne ls.tail)
- println("ListSet.tail should not use a builder!")
// class that always causes hash collisions
case class Collision(value:Int) { override def hashCode = 0 }
diff --git a/test/files/run/t7445.scala b/test/files/run/t7445.scala
deleted file mode 100644
index e4ffeb8e1a..0000000000
--- a/test/files/run/t7445.scala
+++ /dev/null
@@ -1,6 +0,0 @@
-import scala.collection.immutable.ListMap
-
-object Test extends App {
- val a = ListMap(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5);
- require(a.tail == ListMap(2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5));
-}
diff --git a/test/junit/scala/collection/immutable/ListMapTest.scala b/test/junit/scala/collection/immutable/ListMapTest.scala
new file mode 100644
index 0000000000..320a976755
--- /dev/null
+++ b/test/junit/scala/collection/immutable/ListMapTest.scala
@@ -0,0 +1,48 @@
+package scala.collection.immutable
+
+import org.junit.Assert._
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+@RunWith(classOf[JUnit4])
+class ListMapTest {
+
+ @Test
+ def t7445(): Unit = {
+ val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5)
+ assertEquals(ListMap(2 -> 2, 3 -> 3, 4 -> 4, 5 -> 5), m.tail)
+ }
+
+ @Test
+ def hasCorrectBuilder(): Unit = {
+ val m = ListMap("a" -> "1", "b" -> "2", "c" -> "3", "b" -> "2.2", "d" -> "4")
+ assertEquals(List("a" -> "1", "c" -> "3", "b" -> "2.2", "d" -> "4"), m.toList)
+ }
+
+ @Test
+ def hasCorrectHeadTailLastInit(): Unit = {
+ val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3)
+ assertEquals(1 -> 1, m.head)
+ assertEquals(ListMap(2 -> 2, 3 -> 3), m.tail)
+ assertEquals(3 -> 3, m.last)
+ assertEquals(ListMap(1 -> 1, 2 -> 2), m.init)
+ }
+
+ @Test
+ def hasCorrectAddRemove(): Unit = {
+ val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3)
+ assertEquals(ListMap(1 -> 1, 2 -> 2, 3 -> 3, 4 -> 4), m + (4 -> 4))
+ assertEquals(ListMap(1 -> 1, 3 -> 3, 2 -> 4), m + (2 -> 4))
+ assertEquals(ListMap(1 -> 1, 2 -> 2, 3 -> 3), m + (2 -> 2))
+ assertEquals(ListMap(2 -> 2, 3 -> 3), m - 1)
+ assertEquals(ListMap(1 -> 1, 3 -> 3), m - 2)
+ assertEquals(ListMap(1 -> 1, 2 -> 2, 3 -> 3), m - 4)
+ }
+
+ @Test
+ def hasCorrectIterator(): Unit = {
+ val m = ListMap(1 -> 1, 2 -> 2, 3 -> 3, 5 -> 5, 4 -> 4)
+ assertEquals(List(1 -> 1, 2 -> 2, 3 -> 3, 5 -> 5, 4 -> 4), m.iterator.toList)
+ }
+}
diff --git a/test/junit/scala/collection/immutable/ListSetTest.scala b/test/junit/scala/collection/immutable/ListSetTest.scala
new file mode 100644
index 0000000000..395da88c75
--- /dev/null
+++ b/test/junit/scala/collection/immutable/ListSetTest.scala
@@ -0,0 +1,53 @@
+package scala.collection.immutable
+
+import org.junit.Assert._
+import org.junit.Test
+import org.junit.runner.RunWith
+import org.junit.runners.JUnit4
+
+@RunWith(classOf[JUnit4])
+class ListSetTest {
+
+ @Test
+ def t7445(): Unit = {
+ val s = ListSet(1, 2, 3, 4, 5)
+ assertEquals(ListSet(2, 3, 4, 5), s.tail)
+ }
+
+ @Test
+ def hasCorrectBuilder(): Unit = {
+ val m = ListSet("a", "b", "c", "b", "d")
+ assertEquals(List("a", "b", "c", "d"), m.toList)
+ }
+
+ @Test
+ def hasTailRecursiveDelete(): Unit = {
+ val s = ListSet(1 to 50000: _*)
+ try s - 25000 catch { case e: StackOverflowError => fail("A stack overflow occurred") }
+ }
+
+ @Test
+ def hasCorrectHeadTailLastInit(): Unit = {
+ val m = ListSet(1, 2, 3)
+ assertEquals(1, m.head)
+ assertEquals(ListSet(2, 3), m.tail)
+ assertEquals(3, m.last)
+ assertEquals(ListSet(1, 2), m.init)
+ }
+
+ @Test
+ def hasCorrectAddRemove(): Unit = {
+ val m = ListSet(1, 2, 3)
+ assertEquals(ListSet(1, 2, 3, 4), m + 4)
+ assertEquals(ListSet(1, 2, 3), m + 2)
+ assertEquals(ListSet(2, 3), m - 1)
+ assertEquals(ListSet(1, 3), m - 2)
+ assertEquals(ListSet(1, 2, 3), m - 4)
+ }
+
+ @Test
+ def hasCorrectIterator(): Unit = {
+ val s = ListSet(1, 2, 3, 5, 4)
+ assertEquals(List(1, 2, 3, 5, 4), s.iterator.toList)
+ }
+}