summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Josh Suereth <Joshua.Suereth@gmail.com> 2012-06-04 06:27:14 -0700
committer Josh Suereth <Joshua.Suereth@gmail.com> 2012-06-04 06:27:14 -0700
commit e90ea27cbeff305860a583160373a2af5b59581d (patch)
tree 73f6ff0522649e2c52d7908028e3b0d7df38911b
parent c84e019622f4e5458530e1d45fd735e4a90efd75 (diff)
parent ddaf6c5235acbf808ffd5ded710f513f4df3882e (diff)
download scala-e90ea27cbeff305860a583160373a2af5b59581d.tar.gz
scala-e90ea27cbeff305860a583160373a2af5b59581d.tar.bz2
scala-e90ea27cbeff305860a583160373a2af5b59581d.zip
Merge pull request #655 from axel22/feature/hasher
Implementing Hashing typeclass
-rw-r--r-- src/library/scala/collection/concurrent/TrieMap.scala 76
-rw-r--r-- src/library/scala/math/Equiv.scala 2
-rw-r--r-- src/library/scala/util/MurmurHash3.scala 8
-rw-r--r-- src/library/scala/util/hashing/Hashing.scala 42
-rw-r--r-- test/files/run/triemap-hash.scala 46
5 files changed, 145 insertions, 29 deletions
diff --git a/src/library/scala/collection/concurrent/TrieMap.scala b/src/library/scala/collection/concurrent/TrieMap.scala
index 2a908aebb1..f944d20bdb 100644
--- a/src/library/scala/collection/concurrent/TrieMap.scala
+++ b/src/library/scala/collection/concurrent/TrieMap.scala
@@ -14,6 +14,7 @@ package concurrent
import java.util.concurrent.atomic._
import collection.immutable.{ ListMap => ImmutableListMap }
import collection.parallel.mutable.ParTrieMap
+import util.hashing.Hashing
import generic._
import annotation.tailrec
import annotation.switch
@@ -80,6 +81,9 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
} else false
}
+  /** Tests two keys for equivalence using the `equality` of the trie `ct`. */
+  @inline private def equal(k1: K, k2: K, ct: TrieMap[K, V]): Boolean =
+    ct.equality.equiv(k1, k2)
@inline private def inode(cn: MainNode[K, V]) = {
val nin = new INode[K, V](gen)
nin.WRITE(cn)
@@ -117,7 +121,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else false
}
case sn: SNode[K, V] =>
- if (sn.hc == hc && sn.k == k) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)
+ if (sn.hc == hc && equal(sn.k, k, ct)) GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)
else {
val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
@@ -164,7 +168,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
}
case sn: SNode[K, V] => cond match {
case null =>
- if (sn.hc == hc && sn.k == k) {
+ if (sn.hc == hc && equal(sn.k, k, ct)) {
if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
} else {
val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
@@ -173,7 +177,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else null
}
case INode.KEY_ABSENT =>
- if (sn.hc == hc && sn.k == k) Some(sn.v)
+ if (sn.hc == hc && equal(sn.k, k, ct)) Some(sn.v)
else {
val rn = if (cn.gen eq gen) cn else cn.renewed(gen, ct)
val nn = rn.updatedAt(pos, inode(CNode.dual(sn, sn.hc, new SNode(k, v, hc), hc, lev + 5, gen)), gen)
@@ -181,11 +185,11 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else null
}
case INode.KEY_PRESENT =>
- if (sn.hc == hc && sn.k == k) {
+ if (sn.hc == hc && equal(sn.k, k, ct)) {
if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
} else None
case otherv: V =>
- if (sn.hc == hc && sn.k == k && sn.v == otherv) {
+ if (sn.hc == hc && equal(sn.k, k, ct) && sn.v == otherv) {
if (GCAS(cn, cn.updatedAt(pos, new SNode(k, v, hc), gen), ct)) Some(sn.v) else null
} else None
}
@@ -253,7 +257,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else return RESTART // used to be throw RestartException
}
case sn: SNode[K, V] => // 2) singleton node
- if (sn.hc == hc && sn.k == k) sn.v.asInstanceOf[AnyRef]
+ if (sn.hc == hc && equal(sn.k, k, ct)) sn.v.asInstanceOf[AnyRef]
else null
}
}
@@ -296,7 +300,7 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
else null
}
case sn: SNode[K, V] =>
- if (sn.hc == hc && sn.k == k && (v == null || sn.v == v)) {
+ if (sn.hc == hc && equal(sn.k, k, ct) && (v == null || sn.v == v)) {
val ncn = cn.removedAt(pos, flag, gen).toContracted(lev)
if (GCAS(cn, ncn, ct)) Some(sn.v) else null
} else None
@@ -341,11 +345,11 @@ private[collection] final class INode[K, V](bn: MainNode[K, V], g: Gen) extends
case ln: LNode[K, V] =>
if (v == null) {
val optv = ln.get(k)
- val nn = ln.removed(k)
+ val nn = ln.removed(k, ct)
if (GCAS(ln, nn, ct)) optv else null
} else ln.get(k) match {
case optv @ Some(v0) if v0 == v =>
- val nn = ln.removed(k)
+ val nn = ln.removed(k, ct)
if (GCAS(ln, nn, ct)) optv else null
case _ => None
}
@@ -433,12 +437,12 @@ extends MainNode[K, V] {
def this(k: K, v: V) = this(ImmutableListMap(k -> v))
def this(k1: K, v1: V, k2: K, v2: V) = this(ImmutableListMap(k1 -> v1, k2 -> v2))
def inserted(k: K, v: V) = new LNode(listmap + ((k, v)))
- def removed(k: K): MainNode[K, V] = {
+ def removed(k: K, ct: TrieMap[K, V]): MainNode[K, V] = {
val updmap = listmap - k
if (updmap.size > 1) new LNode(updmap)
else {
val (k, v) = updmap.iterator.next
- new TNode(k, v, TrieMap.computeHash(k)) // create it tombed so that it gets compressed on subsequent accesses
+ new TNode(k, v, ct.computeHash(k)) // create it tombed so that it gets compressed on subsequent accesses
}
}
def get(k: K) = listmap.get(k)
@@ -627,25 +631,34 @@ private[concurrent] case class RDCSS_Descriptor[K, V](old: INode[K, V], expected
* @since 2.10
*/
@SerialVersionUID(0L - 6402774413839597105L)
-final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef])
+final class TrieMap[K, V] private (r: AnyRef, rtupd: AtomicReferenceFieldUpdater[TrieMap[K, V], AnyRef], hashf: Hashing[K], ef: Equiv[K])
extends scala.collection.concurrent.Map[K, V]
with scala.collection.mutable.MapLike[K, V, TrieMap[K, V]]
with CustomParallelizable[(K, V), ParTrieMap[K, V]]
with Serializable
{
- import TrieMap.computeHash
-
+ private var hashingobj = if (hashf.isInstanceOf[Hashing.Default[_]]) new TrieMap.MangledHashing[K] else hashf
+ private var equalityobj = ef
private var rootupdater = rtupd
+ def hashing = hashingobj
+ def equality = equalityobj
@volatile var root = r
-
- def this() = this(
+
+ def this(hashf: Hashing[K], ef: Equiv[K]) = this(
INode.newRootNode,
- AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root")
+ AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root"),
+ hashf,
+ ef
)
-
+
+ def this() = this(Hashing.default, Equiv.universal)
+
/* internal methods */
private def writeObject(out: java.io.ObjectOutputStream) {
+    // Write the fields that readObject assigns, not the constructor arguments.
+    out.writeObject(hashingobj)
+    out.writeObject(equalityobj)
val it = iterator
while (it.hasNext) {
val (k, v) = it.next()
@@ -659,6 +672,9 @@ extends scala.collection.concurrent.Map[K, V]
root = INode.newRootNode
rootupdater = AtomicReferenceFieldUpdater.newUpdater(classOf[TrieMap[K, V]], classOf[AnyRef], "root")
+ hashingobj = in.readObject().asInstanceOf[Hashing[K]]
+ equalityobj = in.readObject().asInstanceOf[Equiv[K]]
+
var obj: AnyRef = null
do {
obj = in.readObject()
@@ -780,7 +796,7 @@ extends scala.collection.concurrent.Map[K, V]
@tailrec final def snapshot(): TrieMap[K, V] = {
val r = RDCSS_READ_ROOT()
val expmain = r.gcasRead(this)
- if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater)
+ if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r.copyToGen(new Gen, this), rootupdater, hashing, equality)
else snapshot()
}
@@ -799,7 +815,7 @@ extends scala.collection.concurrent.Map[K, V]
@tailrec final def readOnlySnapshot(): collection.Map[K, V] = {
val r = RDCSS_READ_ROOT()
val expmain = r.gcasRead(this)
- if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null)
+ if (RDCSS_ROOT(r, expmain, r.copyToGen(new Gen, this))) new TrieMap(r, null, hashing, equality)
else readOnlySnapshot()
}
@@ -807,7 +823,10 @@ extends scala.collection.concurrent.Map[K, V]
val r = RDCSS_READ_ROOT()
if (!RDCSS_ROOT(r, r.gcasRead(this), INode.newRootNode[K, V])) clear()
}
-
+
+  /** Hashes `k` with this map's hashing strategy (`hashingobj`). */
+  @inline def computeHash(k: K): Int = hashingobj.hashCode(k)
+
final def lookup(k: K): V = {
val hc = computeHash(k)
lookuphc(k, hc).asInstanceOf[V]
@@ -894,14 +913,15 @@ object TrieMap extends MutableMapFactory[TrieMap] {
implicit def canBuildFrom[K, V]: CanBuildFrom[Coll, (K, V), TrieMap[K, V]] = new MapCanBuildFrom[K, V]
def empty[K, V]: TrieMap[K, V] = new TrieMap[K, V]
-
- @inline final def computeHash[K](k: K): Int = {
- var hcode = k.hashCode
- hcode = hcode * 0x9e3775cd
- hcode = java.lang.Integer.reverseBytes(hcode)
- hcode * 0x9e3775cd
+
+  /** Scrambles a key's `##`: multiply, reverse the bytes, multiply again. */
+  class MangledHashing[K] extends Hashing[K] {
+    def hashCode(k: K): Int = {
+      val scrambled = java.lang.Integer.reverseBytes(k.## * 0x9e3775cd)
+      scrambled * 0x9e3775cd
+    }
+  }
-
+
}
diff --git a/src/library/scala/math/Equiv.scala b/src/library/scala/math/Equiv.scala
index bd8414a18d..a8ba0aa40c 100644
--- a/src/library/scala/math/Equiv.scala
+++ b/src/library/scala/math/Equiv.scala
@@ -29,7 +29,7 @@ import java.util.Comparator
* @since 2.7
*/
-trait Equiv[T] extends Any {
+trait Equiv[T] extends Any with Serializable {
/** Returns `true` iff `x` is equivalent to `y`.
*/
def equiv(x: T, y: T): Boolean
diff --git a/src/library/scala/util/MurmurHash3.scala b/src/library/scala/util/MurmurHash3.scala
index fbb67116dd..9a7f64b4ee 100644
--- a/src/library/scala/util/MurmurHash3.scala
+++ b/src/library/scala/util/MurmurHash3.scala
@@ -1,3 +1,11 @@
+/* __ *\
+** ________ ___ / / ___ Scala API **
+** / __/ __// _ | / / / _ | (c) 2003-2012, LAMP/EPFL **
+** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
+** /____/\___/_/ |_/____/_/ | | **
+** |/ **
+\* */
+
package scala.util
import java.lang.Integer.{ rotateLeft => rotl }
diff --git a/src/library/scala/util/hashing/Hashing.scala b/src/library/scala/util/hashing/Hashing.scala
new file mode 100644
index 0000000000..f27a825125
--- /dev/null
+++ b/src/library/scala/util/hashing/Hashing.scala
@@ -0,0 +1,42 @@
+/* __ *\
+** ________ ___ / / ___ Scala API **
+** / __/ __// _ | / / / _ | (c) 2003-2011, LAMP/EPFL **
+** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
+** /____/\___/_/ |_/____/_/ | | **
+** |/ **
+\* */
+
+package scala.util.hashing
+
+/** `Hashing` is a trait whose instances each represent a strategy for hashing
+ * instances of a type.
+ *
+ * `Hashing`'s companion object defines a default hashing strategy for all
+ * objects - it calls their `##` method.
+ *
+ * Note: a custom `Hashing` must be paired with an `Equiv` such that any two
+ * objects the `Equiv` considers equivalent have equal hash codes.
+ *
+ * @since 2.10
+ */
+@annotation.implicitNotFound(msg = "No implicit Hashing defined for ${T}.")
+trait Hashing[T] extends Serializable {
+ /** Returns the hash code that this strategy assigns to `x`. */
+ def hashCode(x: T): Int
+
+}
+
+
+object Hashing {
+  /** Default strategy: delegates to the `##` method of the hashed value. */
+  final class Default[T] extends Hashing[T] {
+    def hashCode(x: T): Int = x.##
+  }
+  /** Implicitly available default strategy for any `T`; result type is explicit. */
+  implicit def default[T]: Default[T] = new Default[T]
+  /** Wraps the function `f` as a hashing strategy that returns `f(x)`. */
+  def fromFunction[T](f: T => Int): Hashing[T] = new Hashing[T] {
+    def hashCode(x: T): Int = f(x)
+  }
+
+}
diff --git a/test/files/run/triemap-hash.scala b/test/files/run/triemap-hash.scala
new file mode 100644
index 0000000000..7f19997da0
--- /dev/null
+++ b/test/files/run/triemap-hash.scala
@@ -0,0 +1,46 @@
+
+
+
+import util.hashing.Hashing
+
+
+
+/** Exercises `TrieMap` with user-supplied `Hashing` and `Equiv` instances. */
+object Test {
+  def main(args: Array[String]) {
+    hashing()
+    equality()
+  }
+
+  /** Custom hash function combined with the universal equivalence. */
+  def hashing() {
+    val byLengthAndHead = Hashing.fromFunction[String](s => s.length + s(0).toInt)
+    val trie = new collection.concurrent.TrieMap[String, String](byLengthAndHead, Equiv.universal)
+    trie.put("a", "b")
+    trie.put("c", "d")
+    assert(trie("a") == "b")
+    assert(trie("c") == "d")
+
+    val keys = (0 until 1000).map(_.toString)
+    keys.foreach(key => trie(key) = key)
+    keys.foreach(key => assert(trie(key) == key))
+  }
+
+  /** Keys are equivalent, and hash alike, when their first characters match. */
+  def equality() {
+    val byHead = Hashing.fromFunction[String](s => s(0).toInt)
+    val sameHead = Equiv.fromFunction[String]((l, r) => l(0) == r(0))
+    val trie = new collection.concurrent.TrieMap[String, String](byHead, sameHead)
+    trie.put("a", "b")
+    trie.put("a1", "d")
+    trie.put("b", "c")
+
+    // "a1" is equivalent to "a" here, so "a" is expected to map to "d".
+    assert(trie("a") == "d", trie)
+    assert(trie("b") == "c", trie)
+    (0 until 1000).foreach(i => trie(i.toString) = i.toString)
+    assert(trie.size == 12, trie) // "a", "b" and one entry per leading digit
+    assert(trie("0") == "0", trie)
+    (1 to 9).foreach(i => assert(trie(i.toString) == i.toString + "99", trie))
+  }
+}