summaryrefslogtreecommitdiff
path: root/src/library/scala/collection/concurrent
diff options
context:
space:
mode:
authorAleksandar Prokopec <axel22@gmail.com>2012-05-30 14:08:41 +0200
committerAleksandar Prokopec <axel22@gmail.com>2012-06-01 17:24:47 +0200
commitd38ad5e9877011e635b6c2cb6c4f3fcb31dfb7d2 (patch)
treeda7ca5ffe16846f75b27b48dc6f932b08b471058 /src/library/scala/collection/concurrent
parent8d38079ab457e77b3ba57076f3766a158c1eb030 (diff)
downloadscala-d38ad5e9877011e635b6c2cb6c4f3fcb31dfb7d2.tar.gz
scala-d38ad5e9877011e635b6c2cb6c4f3fcb31dfb7d2.tar.bz2
scala-d38ad5e9877011e635b6c2cb6c4f3fcb31dfb7d2.zip
Add Hashing and Equality typeclasses.
Modify TrieMap to use hashing and equality. Modify serialization in TrieMap appropriately.
Diffstat (limited to 'src/library/scala/collection/concurrent')
-rw-r--r--src/library/scala/collection/concurrent/TrieMap.scala72
1 files changed, 45 insertions, 27 deletions
diff --git a/src/library/scala/collection/concurrent/TrieMap.scala b/src/library/scala/collection/concurrent/TrieMap.scala
index 2a908aebb1..b38cb4e376 100644
--- a/src/library/scala/collection/concurrent/TrieMap.scala
+++ b/src/library/scala/collection/concurrent/TrieMap.scala
@@ -14,6 +14,7 @@ package concurrent
import java.util.concurrent.atomic._
import collection.immutable.{ ListMap => ImmutableListMap }
import collection.parallel.mutable.ParTrieMap
+import math.Hashing
import generic._
import annotation.tailrec
import annotation.switch
@@ -80,6 +81,9 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
} else false
}
+ @inline
+ private def equal(k1: K, k2: K, ct: TrieMap[K, V]) = ct.equality.areEqual(k1, k2)
+
@inline private def inode(cn: MainNode[K, V]) = {
val nin = new INode[K, V](gen)
nin.WRITE(cn)
@@ -117,7 +121,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else false
}
case sn: SNode[K, V] =>
- if (sn.hc == hc && sn.k == k) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)
+ if (sn.hc == hc && equal(sn.k, k, ct)) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)
else {
val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
@@ -164,7 +168,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
}
case sn: SNode[K, V] => cond match {
case null =>
- if (sn.hc == hc && sn.k == k) {
+ if (sn.hc == hc && equal(sn.k, k, ct)) {
if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
} else {
val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
@@ -173,7 +177,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else null
}
case INode.KEY_ABSENT =>
- if (sn.hc == hc && sn.k == k) Some(sn.v)
+ if (sn.hc == hc && equal(sn.k, k, ct)) Some(sn.v)
else {
val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
@@ -181,11 +185,11 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else null
}
case INode.KEY_PRESENT =>
- if (sn.hc == hc && sn.k == k) {
+ if (sn.hc == hc && equal(sn.k, k, ct)) {
if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
} else None
case otherv: V =>
- if (sn.hc == hc && sn.k == k && sn.v == otherv) {
+ if (sn.hc == hc && equal(sn.k, k, ct) && sn.v == otherv) {
if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
} else None
}
@@ -253,7 +257,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else return RESTART // used to be throw RestartException
}
case sn: SNode[K, V] => // 2) singleton node
- if (sn.hc == hc && sn.k == k) sn.v.asInstanceOf[AnyRef]
+ if (sn.hc == hc && equal(sn.k, k, ct)) sn.v.asInstanceOf[AnyRef]
else null
}
}
@@ -296,7 +300,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else null
}
case sn: SNode[K, V] =>
- if (sn.hc == hc && sn.k == k && (v == null || sn.v == v)) {
+ if (sn.hc == hc && equal(sn.k, k, ct) && (v == null || sn.v == v)) {
val ncn = cn.removedAt(pos, flag, gen).toContracted(lev)
if (GCAS(cn, ncn, ct)) Some(sn.v) else null
} else None
@@ -341,11 +345,11 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
case ln: LNode[K, V] =>
if (v == null) {
val optv = ln.get(k)
- val nn = ln.removed(k)
+ val nn = ln.removed(k, ct)
if (GCAS(ln, nn, ct)) optv else null
} else ln.get(k) match {
case optv @ Some(v0) if v0 == v =>
- val nn = ln.removed(k)
+ val nn = ln.removed(k, ct)
if (GCAS(ln, nn, ct)) optv else null
case _ => None
}
@@ -433,12 +437,12 @@ extends MainNode[K, V] {
def this(k: K, v: V) = this(ImmutableListMap(k -> v))
def this(k1: K, v1: V, k2: K, v2: V) = this(ImmutableListMap(k1 -> v1, k2 -> v2))
def inserted(k: K, v: V) = new LNode(listmap + ((k, v)))
- def removed(k: K): MainNode[K, V] = {
+ def removed(k: K, ct: TrieMap[K, V]): MainNode[K, V] = {
val updmap = listmap - k
if (updmap.size > 1) new LNode(updmap)
else {
val (k, v) = updmap.iterator.next
- new TNode(k, v, TrieMap.computeHash(k)) // create it tombed so that it gets compressed on subsequent accesses
+ new TNode(k, v, ct.computeHash(k)) // create it tombed so that it gets compressed on subsequent accesses
}
}
def get(k: K) = listmap.get(k)
@@ -627,25 +631,34 @@ private[concurrent] case class RDCSS_Descriptor[K, V](old: INode[K, V], expected
* @since 2.10
*/
@SerialVersionUID(0L - 6402774413839597105L)
-final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef])
+final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef], hashf: Hashing[K], ef: Equality[K])
extends scala.collection.concurrent.Map[K, V]
with scala.collection.mutable.MapLike[K, V, TrieMap[K, V]]
with CustomParallelizable[(K, V), ParTrieMap[K, V]]
with Serializable
{
- import TrieMap.computeHash
-
+ private var hashingobj = hashf
+ private var equalityobj = ef
private var rootupdater = rtupd
+ def hashing = hashingobj
+ def equality = equalityobj
@volatile var root = r
-
- def this() = this(
+
+ def this(hashf: Hashing[K], ef: Equality[K]) = this(
INode.newRootNode,
- AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root")
+ AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root"),
+ hashf,
+ ef
)
-
+
+ def this() = this(Hashing.defaultHashing, Equality.defaultEquality)
+
/* internal methods */
private def writeObject(out: java.io.ObjectOutputStream) {
+ out.writeObject(hashf)
+ out.writeObject(ef)
+
val it = iterator
while (it.hasNext) {
val (k, v) = it.next()
@@ -659,6 +672,9 @@ extends scala.collection.concurrent.Map[K, V]
root = INode.newRootNode
rootupdater = AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root")
+ hashingobj = in.readObject().asInstanceOf[Hashing[K]]
+ equalityobj = in.readObject().asInstanceOf[Equality[K]]
+
var obj: AnyRef = null
do {
obj = in.readObject()
@@ -780,7 +796,7 @@ extends scala.collection.concurrent.Map[K, V]
@tailrec final def snapshot(): TrieMap[K, V] = {
val r = RDCSS_READ_ROOT()
val expmain = r.gcasRead(this)
- if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater)
+ if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater, hashing, equality)
else snapshot()
}
@@ -799,7 +815,7 @@ extends scala.collection.concurrent.Map[K, V]
@tailrec final def readOnlySnapshot(): collection.Map[K, V] = {
val r = RDCSS_READ_ROOT()
val expmain = r.gcasRead(this)
- if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null)
+ if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null, hashing, equality)
else readOnlySnapshot()
}
@@ -807,7 +823,16 @@ extends scala.collection.concurrent.Map[K, V]
val r = RDCSS_READ_ROOT()
if (!RDCSS_ROOT(r, r.gcasRead(this), INode.newRootNode[K, V])) clear()
}
+
+ @inline private def mangle(hc: Int): Int = {
+ var hcode = hc * 0x9e3775cd
+ hcode = java.lang.Integer.reverseBytes(hcode)
+ hcode * 0x9e3775cd
+ }
+ @inline
+ def computeHash(k: K) = mangle(hashingobj.hashCode(k))
+
final def lookup(k: K): V = {
val hc = computeHash(k)
lookuphc(k, hc).asInstanceOf[V]
@@ -895,13 +920,6 @@ object TrieMap extends MutableMapFactory[TrieMap] {
def empty[K, V]: TrieMap[K, V] = new TrieMap[K, V]
- @inline final def computeHash[K](k: K): Int = {
- var hcode = k.hashCode
- hcode = hcode * 0x9e3775cd
- hcode = java.lang.Integer.reverseBytes(hcode)
- hcode * 0x9e3775cd
- }
-
}