author    Marek Kolodziej <mkolod@gmail.com>  2013-11-18 22:00:36 -0500
committer Marek Kolodziej <mkolod@gmail.com>  2013-11-18 22:00:36 -0500
commit    99cfe89c688ee1499d2723d8ea909651995abe86 (patch)
tree      3548a077a71cbb120b195d66cb241e3b1baa31c4 /core
parent    09bdfe3b163559fdcf8771b52ffbe2542883c912 (diff)
download  spark-99cfe89c688ee1499d2723d8ea909651995abe86.tar.gz
          spark-99cfe89c688ee1499d2723d8ea909651995abe86.tar.bz2
          spark-99cfe89c688ee1499d2723d8ea909651995abe86.zip
Updates to reflect pull request code review
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                   4
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala               43
-rw-r--r--  core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala      55
-rw-r--r--  core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala 10
4 files changed, 66 insertions, 46 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index dd9c32f253..e738bfbdc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -17,8 +17,6 @@
package org.apache.spark.rdd
-import org.apache.spark.util.{XORShiftRandom => Random}
-
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
@@ -38,7 +36,7 @@ import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{Utils, BoundedPriorityQueue}
+import org.apache.spark.util.{Utils, BoundedPriorityQueue, XORShiftRandom => Random}
import org.apache.spark.SparkContext._
import org.apache.spark._
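
The effect of the aliasing import is that code inside RDD.scala that instantiates Random now transparently gets the faster generator. A minimal sketch of the idea (the demo object and seed are illustrative, not part of the diff):

import org.apache.spark.util.{XORShiftRandom => Random}

object AliasDemo {
  def main(args: Array[String]): Unit = {
    // Because of the alias, `Random` resolves to XORShiftRandom,
    // not java.util.Random or scala.util.Random.
    val rng = new Random(42L)
    // XORShiftRandom extends java.util.Random, so the full
    // nextInt/nextDouble/nextGaussian API is inherited.
    println(rng.nextInt(100))
  }
}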
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 2df7108d31..b98a81053d 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -818,42 +818,33 @@ private[spark] object Utils extends Logging {
hashAbs
}
- /* Returns a copy of the system properties that is thread-safe to iterator over. */
+ /** Returns a copy of the system properties that is thread-safe to iterate over. */
def getSystemProperties(): Map[String, String] = {
return System.getProperties().clone()
.asInstanceOf[java.util.Properties].toMap[String, String]
}
- /* Used for performance tersting along with the intToTimesInt() and timeIt methods
- * It uses a while loop instead of a for comprehension since the JIT will
- * optimize the while loop better than the "for" closure
- * e.g.
- * import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt}
- * import java.util.Random
- * val rand = new Random()
- * timeIt(rand.nextDouble, 10000000)
+ /**
+ * Executes a block of code repeatedly for its side effects.
+ * Uses a while loop instead of a for comprehension, since the JIT
+ * optimizes the while loop better than the "for" closure.
*/
- class TimesInt(i: Int) {
- def times(f: => Unit) = {
- var x = 1
- while (x <= i) {
- f
- x += 1
+ def times(numIters: Int)(f: => Unit): Unit = {
+ var i = 0
+ while (i < numIters) {
+ f
+ i += 1
}
- }
}
-
- /* Used in conjunction with TimesInt since it's Scala 2.9.3
- * instead of 2.10 and we don't have implicit classes */
- implicit def intToTimesInt(i: Int) = new TimesInt(i)
-
- /* See TimesInt for use example */
- def timeIt(f: => Unit, iters: Int): Long = {
+ /**
+ * Times the execution of a block of code over numIters iterations, using the JIT-friendly times() loop.
+ * @param numIters number of iterations
+ * @param f function to be executed
+ */
+ def timeIt(numIters: Int)(f: => Unit): Long = {
val start = System.currentTimeMillis
- iters.times(f)
+ times(numIters)(f)
System.currentTimeMillis - start
-
}
-
+
}
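
After this refactor, the curried signatures let call sites pass the timed block in braces, and the Scala 2.9.3-era implicit conversion (intToTimesInt) is no longer needed. A self-contained sketch of the two methods and a typical call, mirroring the definitions above (the object name is made up for illustration):

object TimingSketch {
  // Mirrors Utils.times: runs f for its side effects, numIters times.
  def times(numIters: Int)(f: => Unit): Unit = {
    var i = 0
    while (i < numIters) {
      f
      i += 1
    }
  }

  // Mirrors Utils.timeIt: wall-clock milliseconds for numIters runs of f.
  def timeIt(numIters: Int)(f: => Unit): Long = {
    val start = System.currentTimeMillis
    times(numIters)(f)
    System.currentTimeMillis - start
  }

  def main(args: Array[String]): Unit = {
    val rand = new java.util.Random()
    // The curried form reads like a control structure at the call site:
    val ms = timeIt(10000000) { rand.nextDouble() }
    println("10M nextDouble calls took " + ms + " ms")
  }
}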
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
index 3c189c1b69..d443595c24 100644
--- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -18,18 +18,28 @@
package org.apache.spark.util
import java.util.{Random => JavaRandom}
-import Utils.{TimesInt, intToTimesInt, timeIt}
+import org.apache.spark.util.Utils.timeIt
+/**
+ * This class implements an XORShift random number generator algorithm.
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also
+ * because it renounces thread safety. The JDK's implementation uses an AtomicLong seed,
+ * whereas this class uses a regular Long. We can forgo thread safety since we use a new
+ * instance of the RNG for each thread.
+ */
class XORShiftRandom(init: Long) extends JavaRandom(init) {
def this() = this(System.nanoTime)
- var seed = init
+ private var seed = init
// we need to just override next - this will be called by nextInt, nextDouble,
// nextGaussian, nextLong, etc.
- override protected def next(bits: Int): Int = {
-
+ override protected def next(bits: Int): Int = {
var nextSeed = seed ^ (seed << 21)
nextSeed ^= (nextSeed >>> 35)
nextSeed ^= (nextSeed << 4)
@@ -38,25 +48,46 @@ class XORShiftRandom(init: Long) extends JavaRandom(init) {
}
}
+/** Contains a benchmark method and a main method for running the RNG benchmark */
object XORShiftRandom {
+ /**
+ * Main method for running benchmark
+ * @param args takes one argument - the number of random numbers to generate
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
+ println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
+ System.exit(1)
+ }
+ println(benchmark(args(0).toInt))
+ }
+
+ /**
+ * @param numIters Number of random numbers to generate while running the benchmark
+ * @return Map of execution times for {@link java.util.Random java.util.Random}
+ * and XORShift
+ */
def benchmark(numIters: Int) = {
val seed = 1L
val million = 1e6.toInt
val javaRand = new JavaRandom(seed)
val xorRand = new XORShiftRandom(seed)
-
- // warm up the JIT
- million.times {
- javaRand.nextInt
- xorRand.nextInt
+
+ // this is just to warm up the JIT - we're not timing anything
+ timeIt(million) {
+ javaRand.nextInt()
+ xorRand.nextInt()
}
+ val iters = timeIt(numIters)(_)
+
/* Return results as a map instead of just printing to screen
- in case the user wants to do something with them */
- Map("javaTime" -> timeIt(javaRand.nextInt, numIters),
- "xorTime" -> timeIt(xorRand.nextInt, numIters))
+ in case the user wants to do something with them */
+ Map("javaTime" -> iters {javaRand.nextInt()},
+ "xorTime" -> iters {xorRand.nextInt()})
}
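
To make the new generator and its benchmark concrete, here is a usage sketch (the object name, seed, and iteration count are illustrative; timings vary by machine):

import org.apache.spark.util.XORShiftRandom

object XORShiftDemo {
  def main(args: Array[String]): Unit = {
    // Seeded instances are reproducible, just like java.util.Random.
    val rng = new XORShiftRandom(1L)
    println(rng.nextInt())
    println(rng.nextDouble())

    // benchmark returns a Map of millisecond timings rather than printing,
    // e.g. Map("javaTime" -> ..., "xorTime" -> ...), so callers can
    // post-process the results.
    println(XORShiftRandom.benchmark(10000000))
  }
}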
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
index 1691cb4f01..b78367b6ca 100644
--- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -21,7 +21,7 @@ import java.util.Random
import org.scalatest.FlatSpec
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt}
+import org.apache.spark.util.Utils.times
class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
@@ -48,7 +48,7 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
val bins = Array.fill(numBins)(0)
// populate bins based on modulus of the random number
- f.hundMil.times(bins(math.abs(f.xorRand.nextInt) % 10) += 1)
+ times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
/* since the seed is deterministic, until the algorithm is changed, we know the result will be
* exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
@@ -67,9 +67,9 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
* and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
* is greater than or equal to that number.
*/
- val binSize = f.hundMil/numBins
- val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
- xSquared should be < (16.9196)
+ val binSize = f.hundMil/numBins
+ val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+ xSquared should be < (16.9196)
}
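
The uniformity check above is a Pearson chi-squared test: with 10 bins and n draws, the statistic is the sum over bins of (expected - observed)^2 / expected, compared against the ~16.9196 critical value the test comment cites for 9 degrees of freedom. A self-contained sketch of the same computation on a smaller sample (object name, seed, and sample size are illustrative):

import org.apache.spark.util.XORShiftRandom

object ChiSquaredSketch {
  def main(args: Array[String]): Unit = {
    val numBins = 10
    val n = 1000000 // smaller than the suite's 100 million, for a quick run
    val rng = new XORShiftRandom(61L)
    val bins = Array.fill(numBins)(0)
    var i = 0
    while (i < n) {
      // same binning as the suite: modulus of the absolute value
      bins(math.abs(rng.nextInt) % numBins) += 1
      i += 1
    }
    val expected = n.toDouble / numBins
    val xSquared = bins.map(o => math.pow(expected - o, 2) / expected).sum
    println("X-squared = " + xSquared + " (suite's critical value: 16.9196)")
  }
}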