diff options
Diffstat (limited to 'src/library/scala/collection/concurrent/TrieMap.scala')
-rw-r--r-- | src/library/scala/collection/concurrent/TrieMap.scala | 76 |
1 files changed, 48 insertions, 28 deletions
diff --git a/src/library/scala/collection/concurrent/TrieMap.scala b/src/library/scala/collection/concurrent/TrieMap.scala index 2a908aebb1..f944d20bdb 100644 --- a/src/library/scala/collection/concurrent/TrieMap.scala +++ b/src/library/scala/collection/concurrent/TrieMap.scala @@ -14,6 +14,7 @@ package concurrent import java.util.concurrent.atomic._ import collection.immutable.{ ListMap => ImmutableListMap } import collection.parallel.mutable.ParTrieMap +import util.hashing.Hashing import generic._ import annotation.tailrec import annotation.switch @@ -80,6 +81,9 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends } else false } + @inline + private def equal(k1: K, k2: K, ct: TrieMap[K, V]) = ct.equality.equiv(k1, k2) + @inline private def inode(cn: MainNode[K, V]) = { val nin = new INode[K, V](gen) nin.WRITE(cn) @@ -117,7 +121,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends else false } case sn: SNode[K, V] => - if (sn.hc == hc && sn.k == k) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct) + if (sn.hc == hc && equal(sn.k, k, ct)) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct) else { val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct) val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen) @@ -164,7 +168,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends } case sn: SNode[K, V] => cond match { case null => - if (sn.hc == hc && sn.k == k) { + if (sn.hc == hc && equal(sn.k, k, ct)) { if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null } else { val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct) @@ -173,7 +177,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends else null } case INode.KEY_ABSENT => - if (sn.hc == hc && sn.k == k) Some(sn.v) + if (sn.hc == hc && equal(sn.k, k, ct)) Some(sn.v) else { val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct) val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen) @@ -181,11 +185,11 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends else null } case INode.KEY_PRESENT => - if (sn.hc == hc && sn.k == k) { + if (sn.hc == hc && equal(sn.k, k, ct)) { if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null } else None case otherv: V => - if (sn.hc == hc && sn.k == k && sn.v == otherv) { + if (sn.hc == hc && equal(sn.k, k, ct) && sn.v == otherv) { if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null } else None } @@ -253,7 +257,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends else return RESTART // used to be throw RestartException } case sn: SNode[K, V] => // 2) singleton node - if (sn.hc == hc && sn.k == k) sn.v.asInstanceOf[AnyRef] + if (sn.hc == hc && equal(sn.k, k, ct)) sn.v.asInstanceOf[AnyRef] else null } } @@ -296,7 +300,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends else null } case sn: SNode[K, V] => - if (sn.hc == hc && sn.k == k && (v == null || sn.v == v)) { + if (sn.hc == hc && equal(sn.k, k, ct) && (v == null || sn.v == v)) { val ncn = cn.removedAt(pos, flag, gen).toContracted(lev) if (GCAS(cn, ncn, ct)) Some(sn.v) else null } else None @@ -341,11 +345,11 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends case ln: LNode[K, V] => if (v == null) { val optv = ln.get(k) - val nn = ln.removed(k) + val nn = ln.removed(k, ct) if (GCAS(ln, nn, ct)) optv else null } else ln.get(k) match { case optv @ Some(v0) if v0 == v => - val nn = ln.removed(k) + val nn = ln.removed(k, ct) if (GCAS(ln, nn, ct)) optv else null case _ => None } @@ -433,12 +437,12 @@ extends MainNode[K, V] { def this(k: K, v: V) = this(ImmutableListMap(k -> v)) def this(k1: K, v1: V, k2: K, v2: V) = this(ImmutableListMap(k1 -> v1, k2 -> v2)) def inserted(k: K, v: V) = new LNode(listmap + ((k, v))) - def removed(k: K): MainNode[K, V] = { + def removed(k: K, ct: TrieMap[K, V]): MainNode[K, V] = { val updmap = listmap - k if (updmap.size > 1) new LNode(updmap) else { val (k, v) = updmap.iterator.next - new TNode(k, v, TrieMap.computeHash(k)) // create it tombed so that it gets compressed on subsequent accesses + new TNode(k, v, ct.computeHash(k)) // create it tombed so that it gets compressed on subsequent accesses } } def get(k: K) = listmap.get(k) @@ -627,25 +631,34 @@ private[concurrent] case class RDCSS_Descriptor[K, V](old: INode[K, V], expected * @since 2.10 */ @SerialVersionUID(0L - 6402774413839597105L) -final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef]) +final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef], hashf: Hashing[K], ef: Equiv[K]) extends scala.collection.concurrent.Map[K, V] with scala.collection.mutable.MapLike[K, V, TrieMap[K, V]] with CustomParallelizable[(K, V), ParTrieMap[K, V]] with Serializable { - import TrieMap.computeHash - + private var hashingobj = if (hashf.isInstanceOf[Hashing.Default[_]]) new TrieMap.MangledHashing[K] else hashf + private var equalityobj = ef private var rootupdater = rtupd + def hashing = hashingobj + def equality = equalityobj @volatile var root = r - - def this() = this( + + def this(hashf: Hashing[K], ef: Equiv[K]) = this( INode.newRootNode, - AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root") + AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root"), + hashf, + ef ) - + + def this() = this(Hashing.default, Equiv.universal) + /* internal methods */ private def writeObject(out: java.io.ObjectOutputStream) { + out.writeObject(hashf) + out.writeObject(ef) + val it = iterator while (it.hasNext) { val (k, v) = it.next() @@ -659,6 +672,9 @@ extends scala.collection.concurrent.Map[K, V] root = INode.newRootNode rootupdater = AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root") + hashingobj = in.readObject().asInstanceOf[Hashing[K]] + equalityobj = in.readObject().asInstanceOf[Equiv[K]] + var obj: AnyRef = null do { obj = in.readObject() @@ -780,7 +796,7 @@ extends scala.collection.concurrent.Map[K, V] @tailrec final def snapshot(): TrieMap[K, V] = { val r = RDCSS_READ_ROOT() val expmain = r.gcasRead(this) - if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater) + if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater, hashing, equality) else snapshot() } @@ -799,7 +815,7 @@ extends scala.collection.concurrent.Map[K, V] @tailrec final def readOnlySnapshot(): collection.Map[K, V] = { val r = RDCSS_READ_ROOT() val expmain = r.gcasRead(this) - if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null) + if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null, hashing, equality) else readOnlySnapshot() } @@ -807,7 +823,10 @@ extends scala.collection.concurrent.Map[K, V] val r = RDCSS_READ_ROOT() if (!RDCSS_ROOT(r, r.gcasRead(this), INode.newRootNode[K, V])) clear() } - + + @inline + def computeHash(k: K) = hashingobj.hashCode(k) + final def lookup(k: K): V = { val hc = computeHash(k) lookuphc(k, hc).asInstanceOf[V] @@ -894,14 +913,15 @@ object TrieMap extends MutableMapFactory[TrieMap] { implicit def canBuildFrom[K, V]: CanBuildFrom[Coll, (K, V), TrieMap[K, V]] = new MapCanBuildFrom[K, V] def empty[K, V]: TrieMap[K, V] = new TrieMap[K, V] - - @inline final def computeHash[K](k: K): Int = { - var hcode = k.hashCode - hcode = hcode * 0x9e3775cd - hcode = java.lang.Integer.reverseBytes(hcode) - hcode * 0x9e3775cd + + class MangledHashing[K] extends Hashing[K] { + def hashCode(k: K) = { + var hcode = k.## * 0x9e3775cd + hcode = java.lang.Integer.reverseBytes(hcode) + hcode * 0x9e3775cd + } } - + } |