From f5ede0923cd8dd2d45853a06e71e05052e462213 Mon Sep 17 00:00:00 2001
From: Paul Phillips
Date: Wed, 4 Nov 2009 06:36:03 +0000
Subject: Scala implementation of fancier hashCode algori...

Scala implementation of fancier hashCode algorithm. At the moment it
isn't used unless you supply -Yjenkins-hashCodes to scalac. Without the
flag, the supplied test case generates 12559 unique hashCodes among
90000 case class instances; with the flag it generates 89999.
---
 src/compiler/scala/tools/nsc/Settings.scala        |   1 +
 .../tools/nsc/typechecker/SyntheticMethods.scala   |  11 +-
 src/library/scala/runtime/ScalaRunTime.scala       |   3 +
 src/library/scala/util/JenkinsHash.scala           | 207 +++++++++++++++++++++
 test/files/run/hashCodeDistribution.flags          |   1 +
 test/files/run/hashCodeDistribution.scala          |  17 ++
 6 files changed, 236 insertions(+), 4 deletions(-)
 create mode 100644 src/library/scala/util/JenkinsHash.scala
 create mode 100644 test/files/run/hashCodeDistribution.flags
 create mode 100644 test/files/run/hashCodeDistribution.scala

diff --git a/src/compiler/scala/tools/nsc/Settings.scala b/src/compiler/scala/tools/nsc/Settings.scala
index f6221ac14b..95c3a56767 100644
--- a/src/compiler/scala/tools/nsc/Settings.scala
+++ b/src/compiler/scala/tools/nsc/Settings.scala
@@ -837,6 +837,7 @@ trait ScalacSettings {
   val Ypmatdebug       = BooleanSetting ("-Ypmat-debug", "Trace all pattern matcher activity.")
   val Ytailrec         = BooleanSetting ("-Ytailrecommend", "Alert methods which would be tail-recursive if private or final.")
   val YhigherKindedRaw = BooleanSetting ("-Yhigher-kinded-raw", "(temporary!) Treat raw Java types as higher-kinded types.")
+  val Yjenkins         = BooleanSetting ("-Yjenkins-hashCodes", "Use jenkins hash algorithm for case class generated hashCodes.")
 
   // Equality specific
   val logEqEq          = BooleanSetting ("-Ylog-eqeq", "Log all noteworthy equality tests") .
diff --git a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
index e3d75752f2..a3f628ebb4 100644
--- a/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
+++ b/src/compiler/scala/tools/nsc/typechecker/SyntheticMethods.scala
@@ -109,8 +109,8 @@ trait SyntheticMethods extends ast.TreeDSL {
       typer typed { DEF(method) === LIT(clazz.name.decode) }
     }
 
-    def forwardingMethod(name: Name): Tree = {
-      val target = getMember(ScalaRunTimeModule, "_" + name)
+    def forwardingMethod(name: Name, targetName: Name): Tree = {
+      val target = getMember(ScalaRunTimeModule, targetName)
       val paramtypes = target.tpe.paramTypes drop 1
       val method = syntheticMethod(
         name, 0, makeTypeConstructor(paramtypes, target.tpe.resultType)
@@ -123,6 +123,9 @@ trait SyntheticMethods extends ast.TreeDSL {
       }
     }
 
+    def hashCodeTarget: Name =
+      if (settings.Yjenkins.value) "hashCodeJenkins" else nme.hashCode_
+
     def equalsSym = syntheticMethod(
       nme.equals_, 0, makeTypeConstructor(List(AnyClass.tpe), BooleanClass.tpe)
     )
@@ -244,8 +247,8 @@ trait SyntheticMethods extends ast.TreeDSL {
 
     // methods for case classes only
     def classMethods = List(
-      Object_hashCode -> (() => forwardingMethod(nme.hashCode_)),
-      Object_toString -> (() => forwardingMethod(nme.toString_)),
+      Object_hashCode -> (() => forwardingMethod(nme.hashCode_, "_" + hashCodeTarget)),
+      Object_toString -> (() => forwardingMethod(nme.toString_, "_" + nme.toString_)),
       Object_equals -> (() => equalsClassMethod)
     )
     // methods for case objects only
diff --git a/src/library/scala/runtime/ScalaRunTime.scala b/src/library/scala/runtime/ScalaRunTime.scala
index dafa82916e..4541c0147c 100644
--- a/src/library/scala/runtime/ScalaRunTime.scala
+++ b/src/library/scala/runtime/ScalaRunTime.scala
@@ -105,6 +105,9 @@ object ScalaRunTime {
   def _toString(x: Product): String =
     caseFields(x).mkString(x.productPrefix + "(", ",", ")")
 
+  def _hashCodeJenkins(x: Product): Int =
+    scala.util.JenkinsHash.hashSeq(x.productPrefix.toSeq ++ x.productIterator.toSeq)
+
   def _hashCode(x: Product): Int = {
     var code = x.productPrefix.hashCode()
     val arr =  x.productArity
diff --git a/src/library/scala/util/JenkinsHash.scala b/src/library/scala/util/JenkinsHash.scala
new file mode 100644
index 0000000000..4b1198bbc5
--- /dev/null
+++ b/src/library/scala/util/JenkinsHash.scala
@@ -0,0 +1,207 @@
+/*                     __                                               *\
+**     ________ ___   / /  ___     Scala API                            **
+**    / __/ __// _ | / /  / _ |    (c) 2002-2009, LAMP/EPFL             **
+**  __\ \/ /__/ __ |/ /__/ __ |    http://scala-lang.org/               **
+** /____/\___/_/ |_/____/_/ | |                                         **
+**                          |/                                          **
+\*                                                                      */
+
+package scala.util
+
+import java.nio.ByteBuffer
+
+/**
+ * Original algorithm due to Bob Jenkins.
+ * http://burtleburtle.net/bob/c/lookup3.c
+ * Scala version partially adapted from java version by Gray Watson.
+ * http://256.com/sources/jenkins_hash_java/JenkinsHash.java
+ *
+ * This is based on the 1996 version, not the 2006 version, and
+ * could most likely stand some improvement; the collision rate is
+ * negligible in my tests, but performance merits investigation.
+ *
+ * @author Paul Phillips
+ */
+
+object JenkinsHash {
+  /** Mask selecting the low 32 bits of a Long (largest unsigned 32-bit value). */
+  final val MAX_VALUE = 0xFFFFFFFFL
+
+  /** Number of bytes putAnyVal writes for the given value. */
+  private def bytesProvided(v: Any) = v match {
+    case x: Byte    => 1
+    case x: Short   => 2
+    case x: Int     => 4
+    case x: Long    => 8
+    case x: Float   => 4
+    case x: Double  => 8
+    case x: Boolean => 1
+    case x: Char    => 2
+    case x: Unit    => 0
+    case _          => 4
+  }
+
+  /**
+   * Append the bytes of a primitive value to the buffer.  The number
+   * of bytes written must agree with bytesProvided.
+   */
+  private def putAnyVal(bb: ByteBuffer, v: AnyVal) = v match {
+    case x: Byte    => bb put x
+    case x: Short   => bb putShort x
+    case x: Int     => bb putInt x
+    case x: Long    => bb putLong x
+    case x: Float   => bb putFloat x
+    case x: Double  => bb putDouble x
+    // Both branches must write one byte; a bare "else Byte.MinValue"
+    // would write nothing for false and shift every later field.
+    case x: Boolean => bb put (if (x) Byte.MaxValue else Byte.MinValue)
+    case x: Char    => bb putChar x
+    case x: Unit    =>
+  }
+
+  /** Not entirely sure how else one might do this these days, since
+   *  matching on x: AnyVal is a compile time error.  A null element is
+   *  classified as a reference: a type pattern such as x: AnyRef never
+   *  matches null, so it needs its own case to avoid a MatchError.
+   */
+  private def classifyAny(x: Any): (Option[AnyVal], Option[AnyRef]) = x match {
+    case x: Byte    => (Some(x), None)
+    case x: Short   => (Some(x), None)
+    case x: Int     => (Some(x), None)
+    case x: Long    => (Some(x), None)
+    case x: Float   => (Some(x), None)
+    case x: Double  => (Some(x), None)
+    case x: Boolean => (Some(x), None)
+    case x: Char    => (Some(x), None)
+    case x: Unit    => (Some(x), None)
+    case null       => (None, Some(null))
+    case x: AnyRef  => (None, Some(x))
+  }
+
+  /** Split xs into its primitive values and its references (nulls included). */
+  private def partitionValuesAndRefs(xs: Seq[Any]): (Seq[AnyVal], Seq[AnyRef]) = {
+    val (avs, ars) = xs map classifyAny unzip
+
+    (avs.flatten, ars.flatten)
+  }
+
+  /** Lay the primitives out back to back in a byte array and hash it. */
+  private def hashAnyValSeq(xs: Seq[AnyVal]): Int = {
+    val arr = new Array[Byte](xs map bytesProvided sum)
+    val bb = ByteBuffer wrap arr
+    xs foreach (x => putAnyVal(bb, x))
+
+    hash(bb.array()).toInt
+  }
+
+  /**
+   * Convert a byte into a long value without making it negative.
+   */
+  private def byteToLong(b: Byte): Long = b & 0xFFL
+
+  /**
+   * Do addition and turn into 4 bytes.
+   */
+  private def add(x1: Long, x2: Long) = (x1 + x2) & MAX_VALUE
+
+  /**
+   * Do subtraction and turn into 4 bytes.
+   */
+  private def subtract(x1: Long, x2: Long) = (x1 - x2) & MAX_VALUE
+
+  /**
+   * Do exclusive-or and turn into 4 bytes.
+   */
+  private def xor(x1: Long, x2: Long) = (x1 ^ x2) & MAX_VALUE
+
+  /**
+   * Left shift x by shift bits.  Cut down to 4 bytes.
+   */
+  private def leftShift(x: Long, shift: Int) = (x << shift) & MAX_VALUE
+
+  /**
+   * Convert 4 bytes from the buffer at offset into a long value.
+   * The byte at offset becomes the least significant byte.
+   */
+  private def fourByteToLong(bytes: Array[Byte], offset: Int) =
+    0 to 3 map (i => byteToLong(bytes(offset + i)) << (i * 8)) sum
+
+  /**
+   * Hash a sequence of anything into a 32-bit value.  Descendants
+   * of AnyVal are broken down into individual bytes and mixed with
+   * some vigor, and this is summed with the hashCodes provided by
+   * the descendants of AnyRef.  A null reference contributes 0.
+   * Because the reference hashCodes are simply added, reordering the
+   * references among themselves does not change the result.
+   */
+  def hashSeq(xs: Seq[Any]): Int = {
+    val (values, refs) = partitionValuesAndRefs(xs)
+    val refsSum = refs map (x => if (x == null) 0 else x.hashCode) sum
+
+    hashAnyValSeq(values) + refsSum
+  }
+
+  /**
+   * Hash a variable-length key into a 32-bit value.  Every bit of the
+   * key affects every bit of the return value.  Every 1-bit and 2-bit
+   * delta achieves avalanche.  The best hash table sizes are powers of 2.
+   *
+   * @param buffer       Byte array that we are hashing on.
+   * @param initialValue Initial value of the hash if we are continuing from
+   *                     a previous run.  0 if none.  Only its low 32 bits
+   *                     are used.
+   * @return Hash value for the buffer, in the range 0 to MAX_VALUE.
+   */
+  def hash(buffer: Array[Byte], initialValue: Long = 0L): Long = {
+    var a, b = 0x09e3779b9L
+    // The state words emulate unsigned 32-bit integers, so keep c in range.
+    var c = initialValue & MAX_VALUE
+
+    def hashMix(): Long = {
+      a = subtract(a, b); a = subtract(a, c); a = xor(a, c >> 13);
+      b = subtract(b, c); b = subtract(b, a); b = xor(b, leftShift(a, 8));
+      c = subtract(c, a); c = subtract(c, b); c = xor(c, (b >> 13));
+      a = subtract(a, b); a = subtract(a, c); a = xor(a, (c >> 12));
+      b = subtract(b, c); b = subtract(b, a); b = xor(b, leftShift(a, 16));
+      c = subtract(c, a); c = subtract(c, b); c = xor(c, (b >> 5));
+      a = subtract(a, b); a = subtract(a, c); a = xor(a, (c >> 3));
+      b = subtract(b, c); b = subtract(b, a); b = xor(b, leftShift(a, 10));
+      c = subtract(c, a); c = subtract(c, b); c = xor(c, (b >> 15));
+
+      c
+    }
+
+    def mixTwelve(pos: Int) = {
+      a = add(a, fourByteToLong(buffer, pos));
+      b = add(b, fourByteToLong(buffer, pos + 4));
+      c = add(c, fourByteToLong(buffer, pos + 8));
+      hashMix()
+    }
+
+    // mix in blocks of 12, working backwards from the end of the buffer
+    var pos: Int = buffer.length
+    while (pos >= 12) {
+      pos -= 12
+      mixTwelve(pos)
+    }
+    // add() rather than += so that c cannot grow past 32 bits
+    c = add(c, buffer.length)
+
+    // mix any leftover bytes (0-11 remaining); they are the first pos
+    // bytes of the buffer.  The low byte of c is reserved for the length.
+    if (pos > 10) c = add(c, leftShift(byteToLong(buffer(10)), 24))
+    if (pos > 9) c = add(c, leftShift(byteToLong(buffer(9)), 16))
+    if (pos > 8) c = add(c, leftShift(byteToLong(buffer(8)), 8))
+    if (pos > 7) b = add(b, leftShift(byteToLong(buffer(7)), 24))
+    if (pos > 6) b = add(b, leftShift(byteToLong(buffer(6)), 16))
+    if (pos > 5) b = add(b, leftShift(byteToLong(buffer(5)), 8))
+    if (pos > 4) b = add(b, byteToLong(buffer(4)))
+    if (pos > 3) a = add(a, leftShift(byteToLong(buffer(3)), 24))
+    if (pos > 2) a = add(a, leftShift(byteToLong(buffer(2)), 16))
+    if (pos > 1) a = add(a, leftShift(byteToLong(buffer(1)), 8))
+    if (pos > 0) a = add(a, byteToLong(buffer(0)))
+
+    // final mix and result
+    hashMix()
+  }
+}
diff --git a/test/files/run/hashCodeDistribution.flags b/test/files/run/hashCodeDistribution.flags
new file mode 100644
index 0000000000..7806652d4d
--- /dev/null
+++ b/test/files/run/hashCodeDistribution.flags
@@ -0,0 +1 @@
+-Yjenkins-hashCodes
\ No newline at end of file
diff --git a/test/files/run/hashCodeDistribution.scala b/test/files/run/hashCodeDistribution.scala
new file mode 100644
index 0000000000..dbb6e833bd
--- /dev/null
+++ b/test/files/run/hashCodeDistribution.scala
@@ -0,0 +1,17 @@
+// See ticket #2537.
+object Test {
+  case class C(x: Int, y: Int) { }
+  val COUNT = 300
+  val totalCodes = COUNT * COUNT
+
+  def main (args: Array[String]) = {
+    val hashCodes =
+      for (x <- 0 until COUNT; y <- 0 until COUNT) yield C(x,y).hashCode
+
+    val uniques = hashCodes.removeDuplicates
+    val collisionRate = (totalCodes - uniques.size) * 1000 / totalCodes
+
+    assert(collisionRate < 5, "Collision rate too high: %d / 1000".format(collisionRate))
+    // println("collisionRate = %d / 1000".format(collisionRate))
+  }
+}
\ No newline at end of file
-- 
cgit v1.2.3