diff options
author | Imran Rashid <irashid@cloudera.com> | 2015-11-06 20:06:24 +0000 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2015-11-06 20:06:24 +0000 |
commit | 49f1a820372d1cba41f3f00d07eb5728f2ed6705 (patch) | |
tree | 535797cc3662bfd7d8247b2d01f6fd00b2e1b2a9 /core | |
parent | 62bb290773c9f9fa53cbe6d4eedc6e153761a763 (diff) | |
download | spark-49f1a820372d1cba41f3f00d07eb5728f2ed6705.tar.gz spark-49f1a820372d1cba41f3f00d07eb5728f2ed6705.tar.bz2 spark-49f1a820372d1cba41f3f00d07eb5728f2ed6705.zip |
[SPARK-10116][CORE] XORShiftRandom.hashSeed is random in high bits
https://issues.apache.org/jira/browse/SPARK-10116
This is really trivial, just happened to notice it -- if `XORShiftRandom.hashSeed` is really supposed to have random bits throughout (as the comment implies), it needs to do something for the conversion to `long`.
mengxr mkolod
Author: Imran Rashid <irashid@cloudera.com>
Closes #8314 from squito/SPARK-10116.
Diffstat (limited to 'core')
4 files changed, 70 insertions, 23 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala index 85fb923cd9..e8cdb6e98b 100644 --- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala @@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { private[spark] object XORShiftRandom { /** Hash seeds to have 0/1 bits throughout. */ - private def hashSeed(seed: Long): Long = { + private[random] def hashSeed(seed: Long): Long = { val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array() - MurmurHash3.bytesHash(bytes) + val lowBits = MurmurHash3.bytesHash(bytes) + val highBits = MurmurHash3.bytesHash(bytes, lowBits) + (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL) } /** diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index fd8f7f39b7..4d4e982050 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -146,21 +146,29 @@ public class JavaAPISuite implements Serializable { public void sample() { List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); JavaRDD<Integer> rdd = sc.parallelize(ints); - JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 3); + // the seeds here are "magic" to make this work out nicely + JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 8); Assert.assertEquals(2, sample20.count()); - JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 5); + JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 2); Assert.assertEquals(2, sample20WithoutReplacement.count()); } @Test public void randomSplit() { - List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10); + List<Integer> ints = new ArrayList<>(1000); + for (int i = 0; i < 1000; i++) { + ints.add(i); + } JavaRDD<Integer> rdd = sc.parallelize(ints); JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31); + // the splits aren't perfect -- not enough data for them to be -- just check they're about right Assert.assertEquals(3, splits.length); - Assert.assertEquals(1, splits[0].count()); - Assert.assertEquals(2, splits[1].count()); - Assert.assertEquals(7, splits[2].count()); + long s0 = splits[0].count(); + long s1 = splits[1].count(); + long s2 = splits[2].count(); + Assert.assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250); + Assert.assertTrue(s1 + " not within expected range", s1 > 250 && s0 < 350); + Assert.assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570); } @Test diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 1321ec8473..7d2cfcca94 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ import org.apache.hadoop.util.Progressable @@ -578,17 +579,36 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { (x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0" } - def checkSize(exact: Boolean, - withReplacement: Boolean, - expected: Long, - actual: Long, - p: Double): Boolean = { + def assertBinomialSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { + if (exact) { + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new BinomialDistribution(trials, p) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } + } + } + + def assertPoissonSample( + exact: Boolean, + actual: Int, + trials: Int, + p: Double): Unit = { if (exact) { - return expected == actual + assert(actual == math.ceil(p * trials).toInt) + } else { + val dist = new PoissonDistribution(p * trials) + val q = dist.cumulativeProbability(actual) + withClue(s"p = $p: trials = $trials") { + assert(q >= 0.001 && q <= 0.999) + } } - val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p)) - // Very forgiving margin since we're dealing with very small sample sizes most of the time - math.abs(actual - expected) <= 6 * stdev } def testSampleExact(stratifiedData: RDD[(String, Int)], @@ -613,8 +633,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { - val expectedSampleSize = stratifiedData.countByKey() - .mapValues(count => math.ceil(count * samplingRate).toInt) + val trials = stratifiedData.countByKey() val fractions = Map("1" -> samplingRate, "0" -> samplingRate) val sample = if (exact) { stratifiedData.sampleByKeyExact(false, fractions, seed) @@ -623,8 +642,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { } val sampleCounts = sample.countByKey() val takeSample = sample.collect() - sampleCounts.foreach { case(k, v) => - assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) } + sampleCounts.foreach { case (k, v) => + assertBinomialSample(exact = exact, actual = v.toInt, trials = trials(k).toInt, + p = samplingRate) + } assert(takeSample.size === takeSample.toSet.size) takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") } } @@ -635,6 +656,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { samplingRate: Double, seed: Long, n: Long): Unit = { + val trials = stratifiedData.countByKey() val expectedSampleSize = stratifiedData.countByKey().mapValues(count => math.ceil(count * samplingRate).toInt) val fractions = Map("1" -> samplingRate, "0" -> samplingRate) @@ -646,7 +668,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { val sampleCounts = sample.countByKey() val takeSample = sample.collect() sampleCounts.foreach { case (k, v) => - assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate)) + assertPoissonSample(exact, actual = v.toInt, trials = trials(k).toInt, p = samplingRate) } val groupedByKey = takeSample.groupBy(_._1) for ((key, v) <- groupedByKey) { @@ -657,7 +679,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { if (exact) { assert(v.toSet.size <= expectedSampleSize(key)) } else { - assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate)) + assertPoissonSample(false, actual = v.toSet.size, trials(key).toInt, p = samplingRate) } } } diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index d26667bf72..a5b50fce5c 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -65,4 +65,19 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers { val random = new XORShiftRandom(0L) assert(random.nextInt() != 0) } + + test ("hashSeed has random bits throughout") { + val totalBitCount = (0 until 10).map { seed => + val hashed = XORShiftRandom.hashSeed(seed) + val bitCount = java.lang.Long.bitCount(hashed) + // make sure we have roughly equal numbers of 0s and 1s. Mostly just check that we + // don't have all 0s or 1s in the high bits + bitCount should be > 20 + bitCount should be < 44 + bitCount + }.sum + // and over all the seeds, very close to equal numbers of 0s & 1s + totalBitCount should be > (32 * 10 - 30) + totalBitCount should be < (32 * 10 + 30) + } } |