summaryrefslogtreecommitdiff
path: root/test
diff options
context:
space:
mode:
authorAleksandar Prokopec <axel22@gmail.com>2012-02-01 19:54:50 +0100
committerAleksandar Prokopec <axel22@gmail.com>2012-02-01 19:54:50 +0100
commit5fe2d8b109abf3ff3e2d82dd4f248200846795c3 (patch)
treeb50c45368759198fdcd0521016138a3fd7019322 /test
parent8aa87f15e3887dbeb1a39bfea002b56cf68c445a (diff)
downloadscala-5fe2d8b109abf3ff3e2d82dd4f248200846795c3.tar.gz
scala-5fe2d8b109abf3ff3e2d82dd4f248200846795c3.tar.bz2
scala-5fe2d8b109abf3ff3e2d82dd4f248200846795c3.zip
Add the Ctrie concurrent map implementation.
Ctrie is a scalable concurrent map implementation that supports constant time lock-free lazy snapshots. Due to the well-known private volatile field problem, atomic reference updaters cannot be used efficiently in Scala yet. For this reason, 4 java files had to be included as well. None of these pollute the namespace, as most of the classes are private. Unit tests and a scalacheck check is also included.
Diffstat (limited to 'test')
-rw-r--r--test/files/run/ctries/DumbHash.scala14
-rw-r--r--test/files/run/ctries/Wrap.scala9
-rw-r--r--test/files/run/ctries/concmap.scala169
-rw-r--r--test/files/run/ctries/iterator.scala279
-rw-r--r--test/files/run/ctries/lnode.scala58
-rw-r--r--test/files/run/ctries/main.scala45
-rw-r--r--test/files/run/ctries/snapshot.scala267
-rw-r--r--test/files/scalacheck/Ctrie.scala199
8 files changed, 1040 insertions, 0 deletions
diff --git a/test/files/run/ctries/DumbHash.scala b/test/files/run/ctries/DumbHash.scala
new file mode 100644
index 0000000000..8ef325b67c
--- /dev/null
+++ b/test/files/run/ctries/DumbHash.scala
@@ -0,0 +1,14 @@
+
+
+
+
+
+
+class DumbHash(val i: Int) {
+ override def equals(other: Any) = other match {
+ case that: DumbHash => that.i == this.i
+ case _ => false
+ }
+ override def hashCode = i % 5
+ override def toString = "DH(%s)".format(i)
+}
diff --git a/test/files/run/ctries/Wrap.scala b/test/files/run/ctries/Wrap.scala
new file mode 100644
index 0000000000..7b645c1612
--- /dev/null
+++ b/test/files/run/ctries/Wrap.scala
@@ -0,0 +1,9 @@
+
+
+
+
+
+
+case class Wrap(i: Int) {
+ override def hashCode = i * 0x9e3775cd
+}
diff --git a/test/files/run/ctries/concmap.scala b/test/files/run/ctries/concmap.scala
new file mode 100644
index 0000000000..85a305ce5b
--- /dev/null
+++ b/test/files/run/ctries/concmap.scala
@@ -0,0 +1,169 @@
+
+
+
+import collection.mutable.Ctrie
+
+
+object ConcurrentMapSpec extends Spec {
+
+ val initsz = 500
+ val secondsz = 750
+
+ def test() {
+ "support put" in {
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until initsz) assert(ct.put(new Wrap(i), i) == None)
+ for (i <- 0 until initsz) assert(ct.put(new Wrap(i), -i) == Some(i))
+ }
+
+ "support put if absent" in {
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until initsz) ct.update(new Wrap(i), i)
+ for (i <- 0 until initsz) assert(ct.putIfAbsent(new Wrap(i), -i) == Some(i))
+ for (i <- 0 until initsz) assert(ct.putIfAbsent(new Wrap(i), -i) == Some(i))
+ for (i <- initsz until secondsz) assert(ct.putIfAbsent(new Wrap(i), -i) == None)
+ for (i <- initsz until secondsz) assert(ct.putIfAbsent(new Wrap(i), i) == Some(-i))
+ }
+
+ "support remove if mapped to a specific value" in {
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until initsz) ct.update(new Wrap(i), i)
+ for (i <- 0 until initsz) assert(ct.remove(new Wrap(i), -i - 1) == false)
+ for (i <- 0 until initsz) assert(ct.remove(new Wrap(i), i) == true)
+ for (i <- 0 until initsz) assert(ct.remove(new Wrap(i), i) == false)
+ }
+
+ "support replace if mapped to a specific value" in {
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until initsz) ct.update(new Wrap(i), i)
+ for (i <- 0 until initsz) assert(ct.replace(new Wrap(i), -i - 1, -i - 2) == false)
+ for (i <- 0 until initsz) assert(ct.replace(new Wrap(i), i, -i - 2) == true)
+ for (i <- 0 until initsz) assert(ct.replace(new Wrap(i), i, -i - 2) == false)
+ for (i <- initsz until secondsz) assert(ct.replace(new Wrap(i), i, 0) == false)
+ }
+
+ "support replace if present" in {
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until initsz) ct.update(new Wrap(i), i)
+ for (i <- 0 until initsz) assert(ct.replace(new Wrap(i), -i) == Some(i))
+ for (i <- 0 until initsz) assert(ct.replace(new Wrap(i), i) == Some(-i))
+ for (i <- initsz until secondsz) assert(ct.replace(new Wrap(i), i) == None)
+ }
+
+ def assertEqual(a: Any, b: Any) = {
+ if (a != b) println(a, b)
+ assert(a == b)
+ }
+
+ "support replace if mapped to a specific value, using several threads" in {
+ val ct = new Ctrie[Wrap, Int]
+ val sz = 55000
+ for (i <- 0 until sz) ct.update(new Wrap(i), i)
+
+ class Updater(index: Int, offs: Int) extends Thread {
+ override def run() {
+ var repeats = 0
+ for (i <- 0 until sz) {
+ val j = (offs + i) % sz
+ var k = Int.MaxValue
+ do {
+ if (k != Int.MaxValue) repeats += 1
+ k = ct.lookup(new Wrap(j))
+ } while (!ct.replace(new Wrap(j), k, -k))
+ }
+ //println("Thread %d repeats: %d".format(index, repeats))
+ }
+ }
+
+ val threads = for (i <- 0 until 16) yield new Updater(i, sz / 32 * i)
+ threads.foreach(_.start())
+ threads.foreach(_.join())
+
+ for (i <- 0 until sz) assertEqual(ct(new Wrap(i)), i)
+
+ val threads2 = for (i <- 0 until 15) yield new Updater(i, sz / 32 * i)
+ threads2.foreach(_.start())
+ threads2.foreach(_.join())
+
+ for (i <- 0 until sz) assertEqual(ct(new Wrap(i)), -i)
+ }
+
+ "support put if absent, several threads" in {
+ val ct = new Ctrie[Wrap, Int]
+ val sz = 110000
+
+ class Updater(offs: Int) extends Thread {
+ override def run() {
+ for (i <- 0 until sz) {
+ val j = (offs + i) % sz
+ ct.putIfAbsent(new Wrap(j), j)
+ assert(ct.lookup(new Wrap(j)) == j)
+ }
+ }
+ }
+
+ val threads = for (i <- 0 until 16) yield new Updater(sz / 32 * i)
+ threads.foreach(_.start())
+ threads.foreach(_.join())
+
+ for (i <- 0 until sz) assert(ct(new Wrap(i)) == i)
+ }
+
+ "support remove if mapped to a specific value, several threads" in {
+ val ct = new Ctrie[Wrap, Int]
+ val sz = 55000
+ for (i <- 0 until sz) ct.update(new Wrap(i), i)
+
+ class Remover(offs: Int) extends Thread {
+ override def run() {
+ for (i <- 0 until sz) {
+ val j = (offs + i) % sz
+ ct.remove(new Wrap(j), j)
+ assert(ct.get(new Wrap(j)) == None)
+ }
+ }
+ }
+
+ val threads = for (i <- 0 until 16) yield new Remover(sz / 32 * i)
+ threads.foreach(_.start())
+ threads.foreach(_.join())
+
+ for (i <- 0 until sz) assert(ct.get(new Wrap(i)) == None)
+ }
+
+ "have all or none of the elements depending on the oddity" in {
+ val ct = new Ctrie[Wrap, Int]
+ val sz = 65000
+ for (i <- 0 until sz) ct(new Wrap(i)) = i
+
+ class Modifier(index: Int, offs: Int) extends Thread {
+ override def run() {
+ for (j <- 0 until sz) {
+ val i = (offs + j) % sz
+ var success = false
+ do {
+ if (ct.contains(new Wrap(i))) {
+ success = ct.remove(new Wrap(i)) != None
+ } else {
+ success = ct.putIfAbsent(new Wrap(i), i) == None
+ }
+ } while (!success)
+ }
+ }
+ }
+
+ def modify(n: Int) = {
+ val threads = for (i <- 0 until n) yield new Modifier(i, sz / n * i)
+ threads.foreach(_.start())
+ threads.foreach(_.join())
+ }
+
+ modify(16)
+ for (i <- 0 until sz) assertEqual(ct.get(new Wrap(i)), Some(i))
+ modify(15)
+ for (i <- 0 until sz) assertEqual(ct.get(new Wrap(i)), None)
+ }
+
+ }
+
+}
diff --git a/test/files/run/ctries/iterator.scala b/test/files/run/ctries/iterator.scala
new file mode 100644
index 0000000000..1cef4f66ea
--- /dev/null
+++ b/test/files/run/ctries/iterator.scala
@@ -0,0 +1,279 @@
+
+
+
+
+import collection._
+import collection.mutable.Ctrie
+
+
+
+object IteratorSpec extends Spec {
+
+ def test() {
+ "work for an empty trie" in {
+ val ct = new Ctrie
+ val it = ct.iterator
+
+ it.hasNext shouldEqual (false)
+ evaluating { it.next() }.shouldProduce [NoSuchElementException]
+ }
+
+ def nonEmptyIteratorCheck(sz: Int) {
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct.put(new Wrap(i), i)
+
+ val it = ct.iterator
+ val tracker = mutable.Map[Wrap, Int]()
+ for (i <- 0 until sz) {
+ assert(it.hasNext == true)
+ tracker += it.next
+ }
+
+ it.hasNext shouldEqual (false)
+ evaluating { it.next() }.shouldProduce [NoSuchElementException]
+ tracker.size shouldEqual (sz)
+ tracker shouldEqual (ct)
+ }
+
+ "work for a 1 element trie" in {
+ nonEmptyIteratorCheck(1)
+ }
+
+ "work for a 2 element trie" in {
+ nonEmptyIteratorCheck(2)
+ }
+
+ "work for a 3 element trie" in {
+ nonEmptyIteratorCheck(3)
+ }
+
+ "work for a 5 element trie" in {
+ nonEmptyIteratorCheck(5)
+ }
+
+ "work for a 10 element trie" in {
+ nonEmptyIteratorCheck(10)
+ }
+
+ "work for a 20 element trie" in {
+ nonEmptyIteratorCheck(20)
+ }
+
+ "work for a 50 element trie" in {
+ nonEmptyIteratorCheck(50)
+ }
+
+ "work for a 100 element trie" in {
+ nonEmptyIteratorCheck(100)
+ }
+
+ "work for a 1k element trie" in {
+ nonEmptyIteratorCheck(1000)
+ }
+
+ "work for a 5k element trie" in {
+ nonEmptyIteratorCheck(5000)
+ }
+
+ "work for a 75k element trie" in {
+ nonEmptyIteratorCheck(75000)
+ }
+
+ "work for a 250k element trie" in {
+ nonEmptyIteratorCheck(500000)
+ }
+
+ def nonEmptyCollideCheck(sz: Int) {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until sz) ct.put(new DumbHash(i), i)
+
+ val it = ct.iterator
+ val tracker = mutable.Map[DumbHash, Int]()
+ for (i <- 0 until sz) {
+ assert(it.hasNext == true)
+ tracker += it.next
+ }
+
+ it.hasNext shouldEqual (false)
+ evaluating { it.next() }.shouldProduce [NoSuchElementException]
+ tracker.size shouldEqual (sz)
+ tracker shouldEqual (ct)
+ }
+
+ "work for colliding hashcodes, 2 element trie" in {
+ nonEmptyCollideCheck(2)
+ }
+
+ "work for colliding hashcodes, 3 element trie" in {
+ nonEmptyCollideCheck(3)
+ }
+
+ "work for colliding hashcodes, 5 element trie" in {
+ nonEmptyCollideCheck(5)
+ }
+
+ "work for colliding hashcodes, 10 element trie" in {
+ nonEmptyCollideCheck(10)
+ }
+
+ "work for colliding hashcodes, 100 element trie" in {
+ nonEmptyCollideCheck(100)
+ }
+
+ "work for colliding hashcodes, 500 element trie" in {
+ nonEmptyCollideCheck(500)
+ }
+
+ "work for colliding hashcodes, 5k element trie" in {
+ nonEmptyCollideCheck(5000)
+ }
+
+ def assertEqual(a: Map[Wrap, Int], b: Map[Wrap, Int]) {
+ if (a != b) {
+ println(a.size + " vs " + b.size)
+ // println(a)
+ // println(b)
+ // println(a.toSeq.sortBy((x: (Wrap, Int)) => x._1.i))
+ // println(b.toSeq.sortBy((x: (Wrap, Int)) => x._1.i))
+ }
+ assert(a == b)
+ }
+
+ "be consistent when taken with concurrent modifications" in {
+ val sz = 25000
+ val W = 25
+ val S = 10
+ val checks = 5
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct.put(new Wrap(i), i)
+
+ class Modifier extends Thread {
+ override def run() {
+ for (i <- 0 until sz) ct.putIfAbsent(new Wrap(i), i) match {
+ case Some(_) => ct.remove(new Wrap(i))
+ case None =>
+ }
+ }
+ }
+
+ def consistentIteration(ct: Ctrie[Wrap, Int], checks: Int) {
+ class Iter extends Thread {
+ override def run() {
+ val snap = ct.readOnlySnapshot()
+ val initial = mutable.Map[Wrap, Int]()
+ for (kv <- snap) initial += kv
+
+ for (i <- 0 until checks) {
+ assertEqual(snap.iterator.toMap, initial)
+ }
+ }
+ }
+
+ val iter = new Iter
+ iter.start()
+ iter.join()
+ }
+
+ val threads = for (_ <- 0 until W) yield new Modifier
+ threads.foreach(_.start())
+ for (_ <- 0 until S) consistentIteration(ct, checks)
+ threads.foreach(_.join())
+ }
+
+ "be consistent with a concurrent removal with a well defined order" in {
+ val sz = 150000
+ val sgroupsize = 40
+ val sgroupnum = 20
+ val removerslowdown = 50
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct.put(new Wrap(i), i)
+
+ class Remover extends Thread {
+ override def run() {
+ for (i <- 0 until sz) {
+ assert(ct.remove(new Wrap(i)) == Some(i))
+ for (i <- 0 until removerslowdown) ct.get(new Wrap(i)) // slow down, mate
+ }
+ //println("done removing")
+ }
+ }
+
+ def consistentIteration(it: Iterator[(Wrap, Int)]) = {
+ class Iter extends Thread {
+ override def run() {
+ val elems = it.toSeq
+ if (elems.nonEmpty) {
+ val minelem = elems.minBy((x: (Wrap, Int)) => x._1.i)._1.i
+ assert(elems.forall(_._1.i >= minelem))
+ }
+ }
+ }
+ new Iter
+ }
+
+ val remover = new Remover
+ remover.start()
+ for (_ <- 0 until sgroupnum) {
+ val iters = for (_ <- 0 until sgroupsize) yield consistentIteration(ct.iterator)
+ iters.foreach(_.start())
+ iters.foreach(_.join())
+ }
+ //println("done with iterators")
+ remover.join()
+ }
+
+ "be consistent with a concurrent insertion with a well defined order" in {
+ val sz = 150000
+ val sgroupsize = 30
+ val sgroupnum = 30
+ val inserterslowdown = 50
+ val ct = new Ctrie[Wrap, Int]
+
+ class Inserter extends Thread {
+ override def run() {
+ for (i <- 0 until sz) {
+ assert(ct.put(new Wrap(i), i) == None)
+ for (i <- 0 until inserterslowdown) ct.get(new Wrap(i)) // slow down, mate
+ }
+ //println("done inserting")
+ }
+ }
+
+ def consistentIteration(it: Iterator[(Wrap, Int)]) = {
+ class Iter extends Thread {
+ override def run() {
+ val elems = it.toSeq
+ if (elems.nonEmpty) {
+ val maxelem = elems.maxBy((x: (Wrap, Int)) => x._1.i)._1.i
+ assert(elems.forall(_._1.i <= maxelem))
+ }
+ }
+ }
+ new Iter
+ }
+
+ val inserter = new Inserter
+ inserter.start()
+ for (_ <- 0 until sgroupnum) {
+ val iters = for (_ <- 0 until sgroupsize) yield consistentIteration(ct.iterator)
+ iters.foreach(_.start())
+ iters.foreach(_.join())
+ }
+ //println("done with iterators")
+ inserter.join()
+ }
+
+ "work on a yet unevaluated snapshot" in {
+ val sz = 50000
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct.update(new Wrap(i), i)
+
+ val snap = ct.snapshot()
+ val it = snap.iterator
+
+ while (it.hasNext) it.next()
+ }
+
+ }
+
+}
diff --git a/test/files/run/ctries/lnode.scala b/test/files/run/ctries/lnode.scala
new file mode 100644
index 0000000000..28da4cc62f
--- /dev/null
+++ b/test/files/run/ctries/lnode.scala
@@ -0,0 +1,58 @@
+
+
+
+import collection.mutable.Ctrie
+
+
+object LNodeSpec extends Spec {
+
+ val initsz = 1500
+ val secondsz = 1750
+
+ def test() {
+ "accept elements with the same hash codes" in {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until initsz) ct.update(new DumbHash(i), i)
+ }
+
+ "lookup elements with the same hash codes" in {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until initsz) ct.update(new DumbHash(i), i)
+ for (i <- 0 until initsz) assert(ct.get(new DumbHash(i)) == Some(i))
+ for (i <- initsz until secondsz) assert(ct.get(new DumbHash(i)) == None)
+ }
+
+ "remove elements with the same hash codes" in {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until initsz) ct.update(new DumbHash(i), i)
+ for (i <- 0 until initsz) assert(ct.remove(new DumbHash(i)) == Some(i))
+ for (i <- 0 until initsz) assert(ct.get(new DumbHash(i)) == None)
+ }
+
+ "put elements with the same hash codes if absent" in {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until initsz) ct.put(new DumbHash(i), i)
+ for (i <- 0 until initsz) assert(ct.lookup(new DumbHash(i)) == i)
+ for (i <- 0 until initsz) assert(ct.putIfAbsent(new DumbHash(i), i) == Some(i))
+ for (i <- initsz until secondsz) assert(ct.putIfAbsent(new DumbHash(i), i) == None)
+ for (i <- initsz until secondsz) assert(ct.lookup(new DumbHash(i)) == i)
+ }
+
+ "replace elements with the same hash codes" in {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until initsz) assert(ct.put(new DumbHash(i), i) == None)
+ for (i <- 0 until initsz) assert(ct.lookup(new DumbHash(i)) == i)
+ for (i <- 0 until initsz) assert(ct.replace(new DumbHash(i), -i) == Some(i))
+ for (i <- 0 until initsz) assert(ct.lookup(new DumbHash(i)) == -i)
+ for (i <- 0 until initsz) assert(ct.replace(new DumbHash(i), -i, i) == true)
+ }
+
+ "remove elements with the same hash codes if mapped to a specific value" in {
+ val ct = new Ctrie[DumbHash, Int]
+ for (i <- 0 until initsz) assert(ct.put(new DumbHash(i), i) == None)
+ for (i <- 0 until initsz) assert(ct.remove(new DumbHash(i), i) == true)
+ }
+
+ }
+
+}
diff --git a/test/files/run/ctries/main.scala b/test/files/run/ctries/main.scala
new file mode 100644
index 0000000000..8db7fcef54
--- /dev/null
+++ b/test/files/run/ctries/main.scala
@@ -0,0 +1,45 @@
+
+
+
+
+
+
+
+object Test {
+
+ def main(args: Array[String]) {
+ ConcurrentMapSpec.test()
+ IteratorSpec.test()
+ LNodeSpec.test()
+ SnapshotSpec.test()
+ }
+
+}
+
+
+trait Spec {
+
+ implicit def str2ops(s: String) = new {
+ def in[U](body: =>U) {
+ // just execute body
+ body
+ }
+ }
+
+ implicit def any2ops(a: Any) = new {
+ def shouldEqual(other: Any) = assert(a == other)
+ }
+
+ def evaluating[U](body: =>U) = new {
+ def shouldProduce[T <: Throwable: ClassManifest]() = {
+ var produced = false
+ try body
+ catch {
+ case e => if (e.getClass == implicitly[ClassManifest[T]].erasure) produced = true
+ } finally {
+ assert(produced, "Did not produce exception of type: " + implicitly[ClassManifest[T]])
+ }
+ }
+ }
+
+}
diff --git a/test/files/run/ctries/snapshot.scala b/test/files/run/ctries/snapshot.scala
new file mode 100644
index 0000000000..69073d3f06
--- /dev/null
+++ b/test/files/run/ctries/snapshot.scala
@@ -0,0 +1,267 @@
+
+
+
+
+import collection._
+import collection.mutable.Ctrie
+
+
+
+object SnapshotSpec extends Spec {
+
+ def test() {
+ "support snapshots" in {
+ val ctn = new Ctrie
+ ctn.snapshot()
+ ctn.readOnlySnapshot()
+
+ val ct = new Ctrie[Int, Int]
+ for (i <- 0 until 100) ct.put(i, i)
+ ct.snapshot()
+ ct.readOnlySnapshot()
+ }
+
+ "empty 2 quiescent snapshots in isolation" in {
+ val sz = 4000
+
+ class Worker(trie: Ctrie[Wrap, Int]) extends Thread {
+ override def run() {
+ for (i <- 0 until sz) {
+ assert(trie.remove(new Wrap(i)) == Some(i))
+ for (j <- 0 until sz)
+ if (j <= i) assert(trie.get(new Wrap(j)) == None)
+ else assert(trie.get(new Wrap(j)) == Some(j))
+ }
+ }
+ }
+
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct.put(new Wrap(i), i)
+ val snapt = ct.snapshot()
+
+ val original = new Worker(ct)
+ val snapshot = new Worker(snapt)
+ original.start()
+ snapshot.start()
+ original.join()
+ snapshot.join()
+
+ for (i <- 0 until sz) {
+ assert(ct.get(new Wrap(i)) == None)
+ assert(snapt.get(new Wrap(i)) == None)
+ }
+ }
+
+ def consistentReadOnly(name: String, readonly: Map[Wrap, Int], sz: Int, N: Int) {
+ @volatile var e: Exception = null
+
+ // reads possible entries once and stores them
+ // then reads all these N more times to check if the
+ // state stayed the same
+ class Reader(trie: Map[Wrap, Int]) extends Thread {
+ setName("Reader " + name)
+
+ override def run() =
+ try check()
+ catch {
+ case ex: Exception => e = ex
+ }
+
+ def check() {
+ val initial = mutable.Map[Wrap, Int]()
+ for (i <- 0 until sz) trie.get(new Wrap(i)) match {
+ case Some(i) => initial.put(new Wrap(i), i)
+ case None => // do nothing
+ }
+
+ for (k <- 0 until N) {
+ for (i <- 0 until sz) {
+ val tres = trie.get(new Wrap(i))
+ val ires = initial.get(new Wrap(i))
+ if (tres != ires) println(i, "initially: " + ires, "traversal %d: %s".format(k, tres))
+ assert(tres == ires)
+ }
+ }
+ }
+ }
+
+ val reader = new Reader(readonly)
+ reader.start()
+ reader.join()
+
+ if (e ne null) {
+ e.printStackTrace()
+ throw e
+ }
+ }
+
+ // traverses the trie `rep` times and modifies each entry
+ class Modifier(trie: Ctrie[Wrap, Int], index: Int, rep: Int, sz: Int) extends Thread {
+ setName("Modifier %d".format(index))
+
+ override def run() {
+ for (k <- 0 until rep) {
+ for (i <- 0 until sz) trie.putIfAbsent(new Wrap(i), i) match {
+ case Some(_) => trie.remove(new Wrap(i))
+ case None => // do nothing
+ }
+ }
+ }
+ }
+
+ // removes all the elements from the trie
+ class Remover(trie: Ctrie[Wrap, Int], index: Int, totremovers: Int, sz: Int) extends Thread {
+ setName("Remover %d".format(index))
+
+ override def run() {
+ for (i <- 0 until sz) trie.remove(new Wrap((i + sz / totremovers * index) % sz))
+ }
+ }
+
+ "have a consistent quiescent read-only snapshot" in {
+ val sz = 10000
+ val N = 100
+ val W = 10
+
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct(new Wrap(i)) = i
+ val readonly = ct.readOnlySnapshot()
+ val threads = for (i <- 0 until W) yield new Modifier(ct, i, N, sz)
+
+ threads.foreach(_.start())
+ consistentReadOnly("qm", readonly, sz, N)
+ threads.foreach(_.join())
+ }
+
+ // now, we check non-quiescent snapshots, as these permit situations
+ // where a thread is caught in the middle of the update when a snapshot is taken
+
+ "have a consistent non-quiescent read-only snapshot, concurrent with removes only" in {
+ val sz = 1250
+ val W = 100
+ val S = 5000
+
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct(new Wrap(i)) = i
+ val threads = for (i <- 0 until W) yield new Remover(ct, i, W, sz)
+
+ threads.foreach(_.start())
+ for (i <- 0 until S) consistentReadOnly("non-qr", ct.readOnlySnapshot(), sz, 5)
+ threads.foreach(_.join())
+ }
+
+ "have a consistent non-quiescent read-only snapshot, concurrent with modifications" in {
+ val sz = 1000
+ val N = 7000
+ val W = 10
+ val S = 7000
+
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct(new Wrap(i)) = i
+ val threads = for (i <- 0 until W) yield new Modifier(ct, i, N, sz)
+
+ threads.foreach(_.start())
+ for (i <- 0 until S) consistentReadOnly("non-qm", ct.readOnlySnapshot(), sz, 5)
+ threads.foreach(_.join())
+ }
+
+ def consistentNonReadOnly(name: String, trie: Ctrie[Wrap, Int], sz: Int, N: Int) {
+ @volatile var e: Exception = null
+
+ // reads possible entries once and stores them
+ // then reads all these N more times to check if the
+ // state stayed the same
+ class Worker extends Thread {
+ setName("Worker " + name)
+
+ override def run() =
+ try check()
+ catch {
+ case ex: Exception => e = ex
+ }
+
+ def check() {
+ val initial = mutable.Map[Wrap, Int]()
+ for (i <- 0 until sz) trie.get(new Wrap(i)) match {
+ case Some(i) => initial.put(new Wrap(i), i)
+ case None => // do nothing
+ }
+
+ for (k <- 0 until N) {
+ // modify
+ for ((key, value) <- initial) {
+ val oldv = if (k % 2 == 0) value else -value
+ val newv = -oldv
+ trie.replace(key, oldv, newv)
+ }
+
+ // check
+ for (i <- 0 until sz) if (initial.contains(new Wrap(i))) {
+ val expected = if (k % 2 == 0) -i else i
+ //println(trie.get(new Wrap(i)))
+ assert(trie.get(new Wrap(i)) == Some(expected))
+ } else {
+ assert(trie.get(new Wrap(i)) == None)
+ }
+ }
+ }
+ }
+
+ val worker = new Worker
+ worker.start()
+ worker.join()
+
+ if (e ne null) {
+ e.printStackTrace()
+ throw e
+ }
+ }
+
+ "have a consistent non-quiescent snapshot, concurrent with modifications" in {
+ val sz = 9000
+ val N = 1000
+ val W = 10
+ val S = 400
+
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct(new Wrap(i)) = i
+ val threads = for (i <- 0 until W) yield new Modifier(ct, i, N, sz)
+
+ threads.foreach(_.start())
+ for (i <- 0 until S) {
+ consistentReadOnly("non-qm", ct.snapshot(), sz, 5)
+ consistentNonReadOnly("non-qsnap", ct.snapshot(), sz, 5)
+ }
+ threads.foreach(_.join())
+ }
+
+ "work when many concurrent snapshots are taken, concurrent with modifications" in {
+ val sz = 12000
+ val W = 10
+ val S = 10
+ val modifytimes = 1200
+ val snaptimes = 600
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct(new Wrap(i)) = i
+
+ class Snapshooter extends Thread {
+ setName("Snapshooter")
+ override def run() {
+ for (k <- 0 until snaptimes) {
+ val snap = ct.snapshot()
+ for (i <- 0 until sz) snap.remove(new Wrap(i))
+ for (i <- 0 until sz) assert(!snap.contains(new Wrap(i)))
+ }
+ }
+ }
+
+ val mods = for (i <- 0 until W) yield new Modifier(ct, i, modifytimes, sz)
+ val shooters = for (i <- 0 until S) yield new Snapshooter
+ val threads = mods ++ shooters
+ threads.foreach(_.start())
+ threads.foreach(_.join())
+ }
+
+ }
+
+}
diff --git a/test/files/scalacheck/Ctrie.scala b/test/files/scalacheck/Ctrie.scala
new file mode 100644
index 0000000000..2950937278
--- /dev/null
+++ b/test/files/scalacheck/Ctrie.scala
@@ -0,0 +1,199 @@
+
+
+
+import org.scalacheck._
+import Prop._
+import org.scalacheck.Gen._
+import collection._
+import collection.mutable.Ctrie
+
+
+
+case class Wrap(i: Int) {
+ override def hashCode = i // * 0x9e3775cd
+}
+
+
+/** A check mainly oriented towards checking snapshot correctness.
+ */
+object Test extends Properties("Ctrie") {
+
+ /* generators */
+
+ val sizes = choose(0, 200000)
+
+ val threadCounts = choose(2, 16)
+
+ val threadCountsAndSizes = for {
+ p <- threadCounts
+ sz <- sizes
+ } yield (p, sz);
+
+
+ /* helpers */
+
+ def inParallel[T](totalThreads: Int)(body: Int => T): Seq[T] = {
+ val threads = for (idx <- 0 until totalThreads) yield new Thread {
+ setName("ParThread-" + idx)
+ private var res: T = _
+ override def run() {
+ res = body(idx)
+ }
+ def result = {
+ this.join()
+ res
+ }
+ }
+
+ threads foreach (_.start())
+ threads map (_.result)
+ }
+
+ def spawn[T](body: =>T): { def get: T } = {
+ val t = new Thread {
+ setName("SpawnThread")
+ private var res: T = _
+ override def run() {
+ res = body
+ }
+ def result = res
+ }
+ t.start()
+ new {
+ def get: T = {
+ t.join()
+ t.result
+ }
+ }
+ }
+
+ def elementRange(threadIdx: Int, totalThreads: Int, totalElems: Int): Range = {
+ val sz = totalElems
+ val idx = threadIdx
+ val p = totalThreads
+ val start = (sz / p) * idx + math.min(idx, sz % p)
+ val elems = (sz / p) + (if (idx < sz % p) 1 else 0)
+ val end = start + elems
+ (start until end)
+ }
+
+ def hasGrown[K, V](last: Map[K, V], current: Map[K, V]) = {
+ (last.size <= current.size) && {
+ last forall {
+ case (k, v) => current.get(k) == Some(v)
+ }
+ }
+ }
+
+ object err {
+ var buffer = new StringBuilder
+ def println(a: AnyRef) = buffer.append(a.toString).append("\n")
+ def clear() = buffer.clear()
+ def flush() = {
+ Console.out.println(buffer)
+ clear()
+ }
+ }
+
+
+ /* properties */
+
+ property("concurrent growing snapshots") = forAll(threadCounts, sizes) {
+ (numThreads, numElems) =>
+ val p = 3 //numThreads
+ val sz = 102 //numElems
+ val ct = new Ctrie[Wrap, Int]
+
+ // checker
+ val checker = spawn {
+ def check(last: Map[Wrap, Int], iterationsLeft: Int): Boolean = {
+ val current = ct.readOnlySnapshot()
+ if (!hasGrown(last, current)) false
+ else if (current.size >= sz) true
+ else if (iterationsLeft < 0) false
+ else check(current, iterationsLeft - 1)
+ }
+ check(ct.readOnlySnapshot(), 500)
+ }
+
+ // fillers
+ inParallel(p) {
+ idx =>
+ elementRange(idx, p, sz) foreach (i => ct.update(Wrap(i), i))
+ }
+
+ // wait for checker to finish
+ val growing = true//checker.get
+
+ val ok = growing && ((0 until sz) forall {
+ case i => ct.get(Wrap(i)) == Some(i)
+ })
+
+ ok
+ }
+
+ property("update") = forAll(sizes) {
+ (n: Int) =>
+ val ct = new Ctrie[Int, Int]
+ for (i <- 0 until n) ct(i) = i
+ (0 until n) forall {
+ case i => ct(i) == i
+ }
+ }
+
+ property("concurrent update") = forAll(threadCountsAndSizes) {
+ case (p, sz) =>
+ val ct = new Ctrie[Wrap, Int]
+
+ inParallel(p) {
+ idx =>
+ for (i <- elementRange(idx, p, sz)) ct(Wrap(i)) = i
+ }
+
+ (0 until sz) forall {
+ case i => ct(Wrap(i)) == i
+ }
+ }
+
+
+ property("concurrent remove") = forAll(threadCounts, sizes) {
+ (p, sz) =>
+ val ct = new Ctrie[Wrap, Int]
+ for (i <- 0 until sz) ct(Wrap(i)) = i
+
+ inParallel(p) {
+ idx =>
+ for (i <- elementRange(idx, p, sz)) ct.remove(Wrap(i))
+ }
+
+ (0 until sz) forall {
+ case i => ct.get(Wrap(i)) == None
+ }
+ }
+
+
+ property("concurrent putIfAbsent") = forAll(threadCounts, sizes) {
+ (p, sz) =>
+ val ct = new Ctrie[Wrap, Int]
+
+ val results = inParallel(p) {
+ idx =>
+ elementRange(idx, p, sz) find (i => ct.putIfAbsent(Wrap(i), i) != None)
+ }
+
+ (results forall (_ == None)) && ((0 until sz) forall {
+ case i => ct.get(Wrap(i)) == Some(i)
+ })
+ }
+
+}
+
+
+
+
+
+
+
+
+
+