summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Phillips <paulp@improving.org>2009-11-04 06:36:03 +0000
committerPaul Phillips <paulp@improving.org>2009-11-04 06:36:03 +0000
commitf5ede0923cd8dd2d45853a06e71e05052e462213 (patch)
tree731d9f9314e9a1762dd9f513cce15dff8b231829
parentb02b388ffaa5ea6d62669bb91d8af6de521aa71f (diff)
downloadscala-f5ede0923cd8dd2d45853a06e71e05052e462213.tar.gz
scala-f5ede0923cd8dd2d45853a06e71e05052e462213.tar.bz2
scala-f5ede0923cd8dd2d45853a06e71e05052e462213.zip
Scala implementation of fancier hashCode algori...
Scala implementation of fancier hashCode algorithm. At the moment it isn't used unless you supply -Yjenkins-hashCodes to scalac. Without the flag, the supplied test case generates 12559 unique hashCodes among 90000 case class instances; with the flag it generates 89999.
-rw-r--r--src/compiler/scala/tools/nsc/Settings.scala1
-rw-r--r--src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala11
-rw-r--r--src/library/scala/runtime/ScalaRunTime.scala3
-rw-r--r--src/library/scala/util/JenkinsHash.scala191
-rw-r--r--test/files/run/hashCodeDistribution.flags1
-rw-r--r--test/files/run/hashCodeDistribution.scala17
6 files changed, 220 insertions, 4 deletions
diff --git a/src/compiler/scala/tools/nsc/Settings.scala b/src/compiler/scala/tools/nsc/Settings.scala
index f6221ac14b..95c3a56767 100644
--- a/src/compiler/scala/tools/nsc/Settings.scala
+++ b/src/compiler/scala/tools/nsc/Settings.scala
@@ -837,6 +837,7 @@ trait ScalacSettings {
val Ypmatdebug = BooleanSetting ("-Ypmat-debug", "Trace all pattern matcher activity.")
val Ytailrec = BooleanSetting ("-Ytailrecommend", "Alert methods which would be tail-recursive if private or final.")
val YhigherKindedRaw = BooleanSetting ("-Yhigher-kinded-raw", "(temporary!) Treat raw Java types as higher-kinded types.")
+ val Yjenkins = BooleanSetting ("-Yjenkins-hashCodes", "Use jenkins hash algorithm for case class generated hashCodes.")
// Equality specific
val logEqEq = BooleanSetting ("-Ylog-eqeq", "Log all noteworthy equality tests") .
diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
index e3d75752f2..a3f628ebb4 100644
--- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
@@ -109,8 +109,8 @@ trait SyntheticMethods extends ast.TreeDSL {
typer typed { DEF(method) === LIT(clazz.name.decode) }
}
- def forwardingMethod(name: Name): Tree = {
- val target = getMember(ScalaRunTimeModule, "_" + name)
+ def forwardingMethod(name: Name, targetName: Name): Tree = {
+ val target = getMember(ScalaRunTimeModule, targetName)
val paramtypes = target.tpe.paramTypes drop 1
val method = syntheticMethod(
name, 0, makeTypeConstructor(paramtypes, target.tpe.resultType)
@@ -123,6 +123,9 @@ trait SyntheticMethods extends ast.TreeDSL {
}
}
+ def hashCodeTarget: Name =
+ if (settings.Yjenkins.value) "hashCodeJenkins" else nme.hashCode_
+
def equalsSym = syntheticMethod(
nme.equals_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe)
)
@@ -244,8 +247,8 @@ trait SyntheticMethods extends ast.TreeDSL {
// methods for case classes only
def classMethods = List(
- Object_hashCode -> (() => forwardingMethod(nme.hashCode_)),
- Object_toString -> (() => forwardingMethod(nme.toString_)),
+ Object_hashCode -> (() => forwardingMethod(nme.hashCode_, "_" + hashCodeTarget)),
+ Object_toString -> (() => forwardingMethod(nme.toString_, "_" + nme.toString_)),
Object_equals -> (() => equalsClassMethod)
)
// methods for case objects only
diff --git a/src/library/scala/runtime/ScalaRunTime.scala b/src/library/scala/runtime/ScalaRunTime.scala
index dafa82916e..4541c0147c 100644
--- a/src/library/scala/runtime/ScalaRunTime.scala
+++ b/src/library/scala/runtime/ScalaRunTime.scala
@@ -105,6 +105,9 @@ object ScalaRunTime {
def _toString(x: Product): String =
caseFields(x).mkString(x.productPrefix + "(", ",", ")")
+ def _hashCodeJenkins(x: Product): Int =
+ scala.util.JenkinsHash.hashSeq(x.productPrefix.toSeq ++ x.productIterator.toSeq)
+
def _hashCode(x: Product): Int = {
var code = x.productPrefix.hashCode()
val arr = x.productArity
diff --git a/src/library/scala/util/JenkinsHash.scala b/src/library/scala/util/JenkinsHash.scala
new file mode 100644
index 0000000000..4b1198bbc5
--- /dev/null
+++ b/src/library/scala/util/JenkinsHash.scala
@@ -0,0 +1,191 @@
+/* __ *\
+** ________ ___ / / ___ Scala API **
+** / __/ __// _ | / / / _ | (c) 2002-2009, LAMP/EPFL **
+** __\ \/ /__/ __ |/ /__/ __ | http://scala-lang.org/ **
+** /____/\___/_/ |_/____/_/ | | **
+** |/ **
+\* */
+
+package scala.util
+
+import java.nio.ByteBuffer
+
+/**
+ * Original algorithm due to Bob Jenkins.
+ * http://burtleburtle.net/bob/c/lookup3.c
+ * Scala version partially adapted from java version by Gray Watson.
+ * http://256.com/sources/jenkins_hash_java/JenkinsHash.java
+ *
+ * This is based on the 1996 version, not the 2006 version, and
+ * could most likely stand some improvement; the collision rate is
+ * negligible in my tests, but performance merits investigation.
+ *
+ * @author Paul Phillips
+ */
+
+object JenkinsHash {
+ final val MAX_VALUE = 0xFFFFFFFFL
+
+ private def bytesProvided(v: Any) = v match {
+ case x: Byte => 1
+ case x: Short => 2
+ case x: Int => 4
+ case x: Long => 8
+ case x: Float => 4
+ case x: Double => 8
+ case x: Boolean => 1
+ case x: Char => 2
+ case x: Unit => 0
+ case _ => 4
+ }
+
+ private def putAnyVal(bb: ByteBuffer, v: AnyVal) = v match { // appends v's bytes to bb at its current position
+   case x: Byte => bb put x
+   case x: Short => bb putShort x
+   case x: Int => bb putInt x
+   case x: Long => bb putLong x
+   case x: Float => bb putFloat x
+   case x: Double => bb putDouble x
+   case x: Boolean => bb put (if (x) Byte.MaxValue else Byte.MinValue) // always write the one byte bytesProvided reserves; false used to write nothing
+   case x: Char => bb putChar x
+   case x: Unit => // Unit contributes no bytes
+ }
+
+ /** Not entirely sure how else one might do this these days, since
+ * matching on x: AnyVal is a compile time error.
+ */
+ private def classifyAny(x: Any): (Option[AnyVal], Option[AnyRef]) = x match {
+ case x: Byte => (Some(x), None)
+ case x: Short => (Some(x), None)
+ case x: Int => (Some(x), None)
+ case x: Long => (Some(x), None)
+ case x: Float => (Some(x), None)
+ case x: Double => (Some(x), None)
+ case x: Boolean => (Some(x), None)
+ case x: Char => (Some(x), None)
+ case x: Unit => (Some(x), None)
+ case x: AnyRef => (None, Some(x))
+ }
+
+ private def partitionValuesAndRefs(xs: Seq[Any]): (Seq[AnyVal], Seq[AnyRef]) = {
+ val (avs, ars) = xs map classifyAny unzip
+
+ (avs.flatten, ars.flatten)
+ }
+
+ private def hashAnyValSeq(xs: Seq[AnyVal]): Int = {
+ val arr = new Array[Byte](xs map bytesProvided sum)
+ val bb = ByteBuffer wrap arr
+ xs foreach (x => putAnyVal(bb, x))
+
+ hash(bb.array()).toInt
+ }
+
+ /**
+ * Convert a byte into a long value without making it negative.
+ */
+ private def byteToLong(b: Byte): Long = {
+ val res = b & 0x7F
+ if ((b & 0x80) != 0L) res + 128
+ else res
+ }
+
+ /**
+ * Do addition and turn into 4 bytes.
+ */
+ private def add(x1: Long, x2: Long) = (x1 + x2) & MAX_VALUE
+
+ /**
+ * Do subtraction and turn into 4 bytes.
+ */
+ private def subtract(x1: Long, x2: Long) = (x1 - x2) & MAX_VALUE
+
+ /**
+ * Do xor and turn into 4 bytes.
+ */
+ private def xor(x1: Long, x2: Long) = (x1 ^ x2) & MAX_VALUE
+
+ /**
+ * Left shift val by shift bits. Cut down to 4 bytes.
+ */
+ private def leftShift(x: Long, shift: Int) = (x << shift) & MAX_VALUE
+
+ /**
+ * Convert 4 bytes from the buffer at offset into a long value.
+ */
+ private def fourByteToLong(bytes: Array[Byte], offset: Int) =
+ 0 to 3 map (i => byteToLong(bytes(offset + i)) << (i * 8)) sum
+
+ /**
+ * Hash a sequence of anything into a 32-bit value. Descendants
+ * of AnyVal are broken down into individual bytes and mixed with
+ * some vigor, and this is summed with the hashCodes provided by
+ * the descendants of AnyRef.
+ */
+ def hashSeq(xs: Seq[Any]): Int = {
+ val (values, refs) = partitionValuesAndRefs(xs)
+ val refsSum = refs map (x => if (x == null) 0 else x.hashCode) sum
+
+ hashAnyValSeq(values) + refsSum
+ }
+
+ /**
+  * Hash a variable-length key into a 32-bit value. Every bit of the
+  * key affects every bit of the return value. Every 1-bit and 2-bit
+  * delta achieves avalanche. The best hash table sizes are powers of 2.
+  *
+  * @param buffer Byte array that we are hashing on.
+  * @param initialValue Initial value of the hash if we are continuing from
+  * a previous run. 0 if none.
+  * @return Hash value for the buffer, as an unsigned 32-bit quantity (0 to MAX_VALUE) held in a Long.
+  */
+ def hash(buffer: Array[Byte], initialValue: Long = 0L): Long = {
+   var a, b = 0x09e3779b9L
+   var c = initialValue
+
+   def hashMix(): Long = { // scrambles a, b and c in place; every step is masked to 32 bits
+     a = subtract(a, b); a = subtract(a, c); a = xor(a, c >> 13);
+     b = subtract(b, c); b = subtract(b, a); b = xor(b, leftShift(a, 8));
+     c = subtract(c, a); c = subtract(c, b); c = xor(c, (b >> 13));
+     a = subtract(a, b); a = subtract(a, c); a = xor(a, (c >> 12));
+     b = subtract(b, c); b = subtract(b, a); b = xor(b, leftShift(a, 16));
+     c = subtract(c, a); c = subtract(c, b); c = xor(c, (b >> 5));
+     a = subtract(a, b); a = subtract(a, c); a = xor(a, (c >> 3));
+     b = subtract(b, c); b = subtract(b, a); b = xor(b, leftShift(a, 10));
+     c = subtract(c, a); c = subtract(c, b); c = xor(c, (b >> 15));
+
+     c
+   }
+
+   def mixTwelve(pos: Int) = { // folds the 12 bytes starting at pos into a, b, c as three little-endian words
+     a = add(a, fourByteToLong(buffer, pos));
+     b = add(b, fourByteToLong(buffer, pos + 4));
+     c = add(c, fourByteToLong(buffer, pos + 8));
+     hashMix()
+   }
+
+   // mix in blocks of 12, taken from the end of the buffer towards the front
+   var pos: Int = buffer.length
+   while (pos >= 12) {
+     pos -= 12
+     mixTwelve(pos)
+   }
+   c = add(c, buffer.length) // masked like every other update, so no bit above 32 leaks into the shifts in hashMix
+
+   // mix any leftover bytes (0-11 remaining); these are buffer(0) until buffer(pos - 1)
+   if (pos > 10) c = add(c, leftShift(byteToLong(buffer(10)), 24))
+   if (pos > 9) c = add(c, leftShift(byteToLong(buffer(9)), 16))
+   if (pos > 8) c = add(c, leftShift(byteToLong(buffer(8)), 8))
+   if (pos > 7) b = add(b, leftShift(byteToLong(buffer(7)), 24))
+   if (pos > 6) b = add(b, leftShift(byteToLong(buffer(6)), 16))
+   if (pos > 5) b = add(b, leftShift(byteToLong(buffer(5)), 8))
+   if (pos > 4) b = add(b, byteToLong(buffer(4)))
+   if (pos > 3) a = add(a, leftShift(byteToLong(buffer(3)), 24))
+   if (pos > 2) a = add(a, leftShift(byteToLong(buffer(2)), 16))
+   if (pos > 1) a = add(a, leftShift(byteToLong(buffer(1)), 8))
+   if (pos > 0) a = add(a, byteToLong(buffer(0)))
+
+   // final mix and result
+   hashMix()
+ }
+}
diff --git a/test/files/run/hashCodeDistribution.flags b/test/files/run/hashCodeDistribution.flags
new file mode 100644
index 0000000000..7806652d4d
--- /dev/null
+++ b/test/files/run/hashCodeDistribution.flags
@@ -0,0 +1 @@
+-Yjenkins-hashCodes \ No newline at end of file
diff --git a/test/files/run/hashCodeDistribution.scala b/test/files/run/hashCodeDistribution.scala
new file mode 100644
index 0000000000..dbb6e833bd
--- /dev/null
+++ b/test/files/run/hashCodeDistribution.scala
@@ -0,0 +1,17 @@
+// See ticket #2537.
+object Test {
+ case class C(x: Int, y: Int) { }
+ val COUNT = 300
+ val totalCodes = COUNT * COUNT
+
+ def main (args: Array[String]) = {
+ val hashCodes =
+ for (x <- 0 until COUNT; y <- 0 until COUNT) yield C(x,y).hashCode
+
+ val uniques = hashCodes.removeDuplicates
+ val collisionRate = (totalCodes - uniques.size) * 1000 / totalCodes
+
+ assert(collisionRate < 5, "Collision rate too high: %d / 1000".format(collisionRate))
+ // println("collisionRate = %d / 1000".format(collisionRate))
+ }
+} \ No newline at end of file