author     Imran Rashid <irashid@cloudera.com>    2015-11-06 20:06:24 +0000
committer  Sean Owen <sowen@cloudera.com>         2015-11-06 20:06:24 +0000
commit     49f1a820372d1cba41f3f00d07eb5728f2ed6705 (patch)
tree       535797cc3662bfd7d8247b2d01f6fd00b2e1b2a9 /core
parent     62bb290773c9f9fa53cbe6d4eedc6e153761a763 (diff)
[SPARK-10116][CORE] XORShiftRandom.hashSeed is random in high bits
https://issues.apache.org/jira/browse/SPARK-10116

This is really trivial, just happened to notice it -- if `XORShiftRandom.hashSeed` is really supposed to have random bits throughout (as the comment implies), it needs to do something for the conversion to `long`.

mengxr mkolod

Author: Imran Rashid <irashid@cloudera.com>

Closes #8314 from squito/SPARK-10116.
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala      |  6
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java                      | 20
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala       | 52
-rw-r--r--  core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala | 15
4 files changed, 70 insertions(+), 23 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 85fb923cd9..e8cdb6e98b 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -60,9 +60,11 @@ private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
private[spark] object XORShiftRandom {
/** Hash seeds to have 0/1 bits throughout. */
- private def hashSeed(seed: Long): Long = {
+ private[random] def hashSeed(seed: Long): Long = {
val bytes = ByteBuffer.allocate(java.lang.Long.SIZE).putLong(seed).array()
- MurmurHash3.bytesHash(bytes)
+ val lowBits = MurmurHash3.bytesHash(bytes)
+ val highBits = MurmurHash3.bytesHash(bytes, lowBits)
+ (highBits.toLong << 32) | (lowBits.toLong & 0xFFFFFFFFL)
}
/**
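
To see why the old one-hash version lost randomness in the high bits: `MurmurHash3.bytesHash` returns an `Int`, and widening an `Int` to `Long` sign-extends it, so the upper 32 bits of the seed were always all 0s or all 1s. A minimal standalone sketch of the before/after behavior (illustrative only, not the patched Spark code itself):

    import java.nio.ByteBuffer
    import scala.util.hashing.MurmurHash3

    val bytes = ByteBuffer.allocate(8).putLong(42L).array()

    // old behaviour: a single 32-bit hash widened to Long -- the high bits are just sign extension
    val truncated: Long = MurmurHash3.bytesHash(bytes)

    // new behaviour: a second hash seeded by the first fills the upper 32 bits
    val low  = MurmurHash3.bytesHash(bytes)
    val high = MurmurHash3.bytesHash(bytes, low)
    val full = (high.toLong << 32) | (low.toLong & 0xFFFFFFFFL)

    println(f"truncated = $truncated%016x")  // upper half is 00000000 or ffffffff
    println(f"full      = $full%016x")       // all 64 bits vary with the seed
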
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index fd8f7f39b7..4d4e982050 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -146,21 +146,29 @@ public class JavaAPISuite implements Serializable {
public void sample() {
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
JavaRDD<Integer> rdd = sc.parallelize(ints);
- JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 3);
+ // the seeds here are "magic" to make this work out nicely
+ JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 8);
Assert.assertEquals(2, sample20.count());
- JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 5);
+ JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 2);
Assert.assertEquals(2, sample20WithoutReplacement.count());
}
@Test
public void randomSplit() {
- List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
+ List<Integer> ints = new ArrayList<>(1000);
+ for (int i = 0; i < 1000; i++) {
+ ints.add(i);
+ }
JavaRDD<Integer> rdd = sc.parallelize(ints);
JavaRDD<Integer>[] splits = rdd.randomSplit(new double[] { 0.4, 0.6, 1.0 }, 31);
+ // the splits aren't perfect -- not enough data for them to be -- just check they're about right
Assert.assertEquals(3, splits.length);
- Assert.assertEquals(1, splits[0].count());
- Assert.assertEquals(2, splits[1].count());
- Assert.assertEquals(7, splits[2].count());
+ long s0 = splits[0].count();
+ long s1 = splits[1].count();
+ long s2 = splits[2].count();
+ Assert.assertTrue(s0 + " not within expected range", s0 > 150 && s0 < 250);
+ Assert.assertTrue(s1 + " not within expected range", s1 > 250 && s1 < 350);
+ Assert.assertTrue(s2 + " not within expected range", s2 > 430 && s2 < 570);
}
@Test
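
The new bounds are loose on purpose. `randomSplit` normalizes the weights (0.4, 0.6, 1.0 sum to 2.0, giving fractions 0.2, 0.3, 0.5), so with 1000 rows the expected split sizes are about 200, 300 and 500, and each size is roughly Binomial(1000, p). A quick sketch of the arithmetic behind the asserted ranges (illustrative, not part of the patch):

    val n = 1000
    val weights = Array(0.4, 0.6, 1.0)
    val fractions = weights.map(_ / weights.sum)        // 0.2, 0.3, 0.5
    fractions.foreach { p =>
      val mean  = n * p
      val stdev = math.sqrt(n * p * (1 - p))
      println(f"expected ~ $mean%.0f +/- ${3 * stdev}%.0f (3 sigma)")
    }
    // prints roughly 200 +/- 38, 300 +/- 43, 500 +/- 47; the asserted ranges
    // (150..250, 250..350, 430..570) are deliberately a little wider than that.
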
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 1321ec8473..7d2cfcca94 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.rdd
+import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution}
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.mapred._
import org.apache.hadoop.util.Progressable
@@ -578,17 +579,36 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
(x: Int) => if (x % 10 < (10 * fractionPositive).toInt) "1" else "0"
}
- def checkSize(exact: Boolean,
- withReplacement: Boolean,
- expected: Long,
- actual: Long,
- p: Double): Boolean = {
+ def assertBinomialSample(
+ exact: Boolean,
+ actual: Int,
+ trials: Int,
+ p: Double): Unit = {
+ if (exact) {
+ assert(actual == math.ceil(p * trials).toInt)
+ } else {
+ val dist = new BinomialDistribution(trials, p)
+ val q = dist.cumulativeProbability(actual)
+ withClue(s"p = $p: trials = $trials") {
+ assert(q >= 0.001 && q <= 0.999)
+ }
+ }
+ }
+
+ def assertPoissonSample(
+ exact: Boolean,
+ actual: Int,
+ trials: Int,
+ p: Double): Unit = {
if (exact) {
- return expected == actual
+ assert(actual == math.ceil(p * trials).toInt)
+ } else {
+ val dist = new PoissonDistribution(p * trials)
+ val q = dist.cumulativeProbability(actual)
+ withClue(s"p = $p: trials = $trials") {
+ assert(q >= 0.001 && q <= 0.999)
+ }
}
- val stdev = if (withReplacement) math.sqrt(expected) else math.sqrt(expected * p * (1 - p))
- // Very forgiving margin since we're dealing with very small sample sizes most of the time
- math.abs(actual - expected) <= 6 * stdev
}
def testSampleExact(stratifiedData: RDD[(String, Int)],
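
Both helpers above replace the old fixed six-sigma band with the same idea: ask the sampling distribution itself whether the observed count is plausible, failing only when its cumulative probability falls in the extreme 0.1% of either tail. A self-contained sketch of that check (assuming commons-math3 on the test classpath, as the suite already does; the names here are illustrative):

    import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution}

    // Without replacement the per-key count is ~ Binomial(trials, p);
    // with replacement it is ~ Poisson(p * trials).
    def plausibleBinomial(actual: Int, trials: Int, p: Double): Boolean = {
      val q = new BinomialDistribution(trials, p).cumulativeProbability(actual)
      q >= 0.001 && q <= 0.999
    }

    def plausiblePoisson(actual: Int, trials: Int, p: Double): Boolean = {
      val q = new PoissonDistribution(p * trials).cumulativeProbability(actual)
      q >= 0.001 && q <= 0.999
    }

    // e.g. sampling 1000 rows at 10%: ~100 hits expected
    println(plausibleBinomial(actual = 100, trials = 1000, p = 0.1))  // true
    println(plausibleBinomial(actual = 150, trials = 1000, p = 0.1))  // false: ~5 sigma too high
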
@@ -613,8 +633,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
samplingRate: Double,
seed: Long,
n: Long): Unit = {
- val expectedSampleSize = stratifiedData.countByKey()
- .mapValues(count => math.ceil(count * samplingRate).toInt)
+ val trials = stratifiedData.countByKey()
val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
val sample = if (exact) {
stratifiedData.sampleByKeyExact(false, fractions, seed)
@@ -623,8 +642,10 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
}
val sampleCounts = sample.countByKey()
val takeSample = sample.collect()
- sampleCounts.foreach { case(k, v) =>
- assert(checkSize(exact, false, expectedSampleSize(k), v, samplingRate)) }
+ sampleCounts.foreach { case (k, v) =>
+ assertBinomialSample(exact = exact, actual = v.toInt, trials = trials(k).toInt,
+ p = samplingRate)
+ }
assert(takeSample.size === takeSample.toSet.size)
takeSample.foreach { x => assert(1 <= x._2 && x._2 <= n, s"elements not in [1, $n]") }
}
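
For readers less familiar with the API this test exercises, a minimal usage sketch of stratified sampling without replacement (`sc` is assumed to be a local SparkContext; the data shape mirrors the suite's):

    val n = 1000
    val samplingRate = 0.1
    val data = sc.parallelize(1 to n).map(x => (if (x % 10 < 3) "1" else "0", x))

    val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
    val trials = data.countByKey()                    // population size per stratum
    val sample = data.sampleByKey(withReplacement = false, fractions, seed = 42L)
    val sampleCounts = sample.countByKey()            // observed size per stratum
    // each sampleCounts(k) should be a plausible Binomial(trials(k), samplingRate) draw,
    // which is what assertBinomialSample verifies above
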
@@ -635,6 +656,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
samplingRate: Double,
seed: Long,
n: Long): Unit = {
+ val trials = stratifiedData.countByKey()
val expectedSampleSize = stratifiedData.countByKey().mapValues(count =>
math.ceil(count * samplingRate).toInt)
val fractions = Map("1" -> samplingRate, "0" -> samplingRate)
@@ -646,7 +668,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
val sampleCounts = sample.countByKey()
val takeSample = sample.collect()
sampleCounts.foreach { case (k, v) =>
- assert(checkSize(exact, true, expectedSampleSize(k), v, samplingRate))
+ assertPoissonSample(exact, actual = v.toInt, trials = trials(k).toInt, p = samplingRate)
}
val groupedByKey = takeSample.groupBy(_._1)
for ((key, v) <- groupedByKey) {
@@ -657,7 +679,7 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext {
if (exact) {
assert(v.toSet.size <= expectedSampleSize(key))
} else {
- assert(checkSize(false, true, expectedSampleSize(key), v.toSet.size, samplingRate))
+ assertPoissonSample(false, actual = v.toSet.size, trials(key).toInt, p = samplingRate)
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
index d26667bf72..a5b50fce5c 100644
--- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
@@ -65,4 +65,19 @@ class XORShiftRandomSuite extends SparkFunSuite with Matchers {
val random = new XORShiftRandom(0L)
assert(random.nextInt() != 0)
}
+
+ test ("hashSeed has random bits throughout") {
+ val totalBitCount = (0 until 10).map { seed =>
+ val hashed = XORShiftRandom.hashSeed(seed)
+ val bitCount = java.lang.Long.bitCount(hashed)
+ // make sure we have roughly equal numbers of 0s and 1s. Mostly just check that we
+ // don't have all 0s or 1s in the high bits
+ bitCount should be > 20
+ bitCount should be < 44
+ bitCount
+ }.sum
+ // and over all the seeds, very close to equal numbers of 0s & 1s
+ totalBitCount should be > (32 * 10 - 30)
+ totalBitCount should be < (32 * 10 + 30)
+ }
}
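
The numbers in this new test follow from treating each bit of a well-mixed seed as a fair coin: `java.lang.Long.bitCount` of such a value is roughly Binomial(64, 0.5), with mean 32 and standard deviation 4. A small illustrative sketch of where the 20..44 and +/- 30 bounds come from:

    // Per seed: mean 32 set bits, stdev = sqrt(64 * 0.5 * 0.5) = 4,
    // so 20..44 is a +/- 3 sigma band around the mean.
    val perSeedStdev = math.sqrt(64 * 0.25)
    println(s"per-seed band: ${32 - 3 * perSeedStdev} .. ${32 + 3 * perSeedStdev}")  // 20.0 .. 44.0

    // Summed over 10 seeds: mean 320, stdev = sqrt(10) * 4 ~= 12.6,
    // so the +/- 30 band on the total is roughly 2.4 sigma.
    val sumStdev = math.sqrt(10.0) * perSeedStdev
    println(f"sum band: 290 .. 350 (+/- ${30 / sumStdev}%.1f sigma)")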