4 files changed, 200 insertions, 5 deletions
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index fe932d8ede..a79e64e810 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -823,4 +823,28 @@ private[spark] object Utils extends Logging {
     return System.getProperties().clone()
       .asInstanceOf[java.util.Properties].toMap[String, String]
   }
+
+  /**
+   * Method executed for repeating a task for side effects.
+   * Unlike a for comprehension, it permits JVM JIT optimization
+   */
+  def times(numIters: Int)(f: => Unit): Unit = {
+    var i = 0
+    while (i < numIters) {
+      f
+      i += 1
+    }
+  }
+
+  /**
+   * Timing method based on iterations that permit JVM JIT optimization.
+   * @param numIters number of iterations
+   * @param f function to be executed
+   */
+  def timeIt(numIters: Int)(f: => Unit): Long = {
+    val start = System.currentTimeMillis
+    times(numIters)(f)
+    System.currentTimeMillis - start
+  }
+
 }
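The two new helpers are worth a quick illustration. Below is a minimal usage sketch, not part of the patch; the object name and iteration counts are made up, and since Utils is private[spark], code like this would have to be compiled inside the org.apache.spark package.

package org.apache.spark.util

// Illustrative only: exercises the times/timeIt helpers added to Utils above.
object TimingSketch {
  def main(args: Array[String]): Unit = {
    var sum = 0L

    // times repeats a side-effecting block numIters times via a plain while loop.
    Utils.times(1000) { sum += 1 }
    println(s"sum after 1000 iterations: $sum")

    // timeIt returns the wall-clock milliseconds taken by numIters executions of the block.
    val elapsedMs = Utils.timeIt(1000000) { sum += 1 }
    println(s"1e6 increments took $elapsedMs ms")
  }
}

The point of the explicit while loop in times is that a Scala for comprehension desugars into a closure passed to Range.foreach, which the JIT optimizes less readily than a simple counted loop; that matters when the timed body is as cheap as a single nextInt() call.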
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
new file mode 100644
index 0000000000..e9907e6c85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.{Random => JavaRandom}
+import org.apache.spark.util.Utils.timeIt
+
+/**
+ * This class implements a XORShift random number generator algorithm
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
+ * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
+ * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
+ * for each thread.
+ */
+private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
+
+  def this() = this(System.nanoTime)
+
+  private var seed = init
+
+  // we need to just override next - this will be called by nextInt, nextDouble,
+  // nextGaussian, nextLong, etc.
+  override protected def next(bits: Int): Int = {
+    var nextSeed = seed ^ (seed << 21)
+    nextSeed ^= (nextSeed >>> 35)
+    nextSeed ^= (nextSeed << 4)
+    seed = nextSeed
+    (nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
+  }
+}
+
+/** Contains benchmark method and main method to run benchmark of the RNG */
+private[spark] object XORShiftRandom {
+
+  /**
+   * Main method for running benchmark
+   * @param args takes one argument - the number of random numbers to generate
+   */
+  def main(args: Array[String]): Unit = {
+    if (args.length != 1) {
+      println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
+      println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
+      System.exit(1)
+    }
+    println(benchmark(args(0).toInt))
+  }
+
+  /**
+   * @param numIters Number of random numbers to generate while running the benchmark
+   * @return Map of execution times for {@link java.util.Random java.util.Random}
+   * and XORShift
+   */
+  def benchmark(numIters: Int) = {
+
+    val seed = 1L
+    val million = 1e6.toInt
+    val javaRand = new JavaRandom(seed)
+    val xorRand = new XORShiftRandom(seed)
+
+    // this is just to warm up the JIT - we're not timing anything
+    timeIt(1e6.toInt) {
+      javaRand.nextInt()
+      xorRand.nextInt()
+    }
+
+    val iters = timeIt(numIters)(_)
+
+    /* Return results as a map instead of just printing to screen
+    in case the user wants to do something with them */
+    Map("javaTime" -> iters {javaRand.nextInt()},
+        "xorTime" -> iters {xorRand.nextInt()})
+
+  }
+
+}
\ No newline at end of file
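As a quick illustration of how the new class is used (a sketch, not part of the patch; the object name is made up, and it has to live under org.apache.spark because XORShiftRandom is private[spark]): only next(bits) is overridden, so every generator method inherited from java.util.Random draws from the XORShift sequence, and the companion object's benchmark can be invoked directly.

package org.apache.spark.util

// Illustrative only: XORShiftRandom behaves as a drop-in replacement for java.util.Random.
object XORShiftRandomSketch {
  def main(args: Array[String]): Unit = {
    val rng = new XORShiftRandom(42L)
    println("5 doubles:   " + Array.fill(5)(rng.nextDouble()).mkString(", "))
    println("5 gaussians: " + Array.fill(5)(rng.nextGaussian()).mkString(", "))

    // benchmark returns a Map of elapsed milliseconds for both generators.
    val results = XORShiftRandom.benchmark(10000000)
    println(results) // e.g. Map(javaTime -> ..., xorTime -> ...)
  }
}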
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
new file mode 100644
index 0000000000..b78367b6ca
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Random
+import org.scalatest.FlatSpec
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.Utils.times
+
+class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+
+  def fixture = new {
+    val seed = 1L
+    val xorRand = new XORShiftRandom(seed)
+    val hundMil = 1e8.toInt
+  }
+
+  /*
+   * This test is based on a chi-squared test for randomness. The values are hard-coded
+   * so as not to create Spark's dependency on apache.commons.math3 just to call one
+   * method for calculating the exact p-value for a given number of random numbers
+   * and bins. In case one would want to move to a full-fledged test based on
+   * apache.commons.math3, the relevant class is here:
+   * org.apache.commons.math3.stat.inference.ChiSquareTest
+   */
+  test ("XORShift generates valid random numbers") {
+
+    val f = fixture
+
+    val numBins = 10
+    // create 10 bins
+    val bins = Array.fill(numBins)(0)
+
+    // populate bins based on modulus of the random number
+    times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
+
+    /* since the seed is deterministic, until the algorithm is changed, we know the result will be
+     * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+     * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
+     * significance level. However, should the RNG implementation change, the test should still
+     * pass at the same significance level. The chi-squared test done in R gave the following
+     * results:
+     *   > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+     *     10000790, 10002286, 9998699))
+     *     Chi-squared test for given probabilities
+     *     data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
+     *           10002286, 9998699)
+     *     X-squared = 11.975, df = 9, p-value = 0.2147
+     * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million
+     * random numbers
+     * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
+     * is greater than or equal to that number.
+     */
+    val binSize = f.hundMil/numBins
+    val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+    xSquared should be < (16.9196)
+
+  }
+
+}
\ No newline at end of file
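The figures quoted in the test comment are easy to sanity-check outside Spark. The following standalone sketch (illustrative only, not part of the patch) recomputes the X-squared statistic for the hard-coded bin counts; with 10 bins there are 9 degrees of freedom, and the 5% critical value of roughly 16.919 is where the test's threshold comes from.

// Standalone check of the numbers quoted in the test comment above (illustrative only).
object ChiSquaredCheck {
  def main(args: Array[String]): Unit = {
    val bins = Array(10004908, 9993136, 9994600, 10000744, 10000091,
                     10002474, 10002272, 10000790, 10002286, 9998699).map(_.toDouble)
    val expected = bins.sum / bins.length          // 1e7 draws expected per bin
    val xSquared = bins.map(o => math.pow(o - expected, 2) / expected).sum
    println(f"X-squared = $xSquared%.3f, critical value at alpha = 0.05 with df = 9 is about 16.919")
    // Prints X-squared = 11.975, matching the R output in the comment, so the
    // deterministic bin counts pass the 5% chi-squared test comfortably.
  }
}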
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index edbf77dbcc..0dee9399a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -18,15 +18,16 @@
 package org.apache.spark.mllib.clustering
 
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
+
+import org.jblas.DoubleMatrix
 
 import org.apache.spark.SparkContext
 import org.apache.spark.SparkContext._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.Logging
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.XORShiftRandom
 
-import org.jblas.DoubleMatrix
 
 
 /**
@@ -195,7 +196,7 @@ class KMeans private (
    */
   private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
     // Sample all the cluster centers in one pass to avoid repeated scans
-    val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
+    val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
     Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
   }
 
@@ -210,7 +211,7 @@ class KMeans private (
    */
   private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
     // Initialize each run's center to a random point
-    val seed = new Random().nextInt()
+    val seed = new XORShiftRandom().nextInt()
     val sample = data.takeSample(true, runs, seed).toSeq
     val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
 
@@ -222,7 +223,7 @@ class KMeans private (
         for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
       }.reduceByKey(_ + _).collectAsMap()
       val chosen = data.mapPartitionsWithIndex { (index, points) =>
-        val rand = new Random(seed ^ (step << 16) ^ index)
+        val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
         for {
           p <- points
           r <- 0 until runs
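One detail worth calling out in the last hunk is the seed derivation seed ^ (step << 16) ^ index: each (initialization step, partition index) pair gets its own deterministic seed derived from the run's base seed, so the per-partition streams differ across steps and partitions without shipping RNG state around. A standalone sketch of the idea follows (illustrative only; it uses java.util.Random as a stand-in because XORShiftRandom is private[spark], and the names are made up).

import java.util.{Random => JavaRandom}

// Illustrative only: mimics how the KMeans|| init derives one RNG per (step, partition) pair.
object PartitionSeedingSketch {
  def main(args: Array[String]): Unit = {
    val baseSeed = 12345 // the patch draws this from new XORShiftRandom().nextInt()
    for (step <- 0 until 2; index <- 0 until 3) {
      val rand = new JavaRandom(baseSeed ^ (step << 16) ^ index)
      println(s"step=$step partition=$index first draw=${rand.nextInt()}")
    }
  }
}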