aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala64
2 files changed, 43 insertions, 25 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
index 5fdf878a3d..8d4174124b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -67,6 +67,8 @@ private[spark] abstract class Stopwatch extends Serializable {
*/
def elapsed(): Long
+ override def toString: String = s"$name: ${elapsed()}ms"
+
/**
* Gets the current time in milliseconds.
*/
@@ -145,7 +147,7 @@ private[spark] class MultiStopwatch(@transient private val sc: SparkContext) ext
override def toString: String = {
stopwatches.values.toArray.sortBy(_.name)
- .map(c => s" ${c.name}: ${c.elapsed()}ms")
+ .map(c => s" $c")
.mkString("{\n", ",\n", "\n}")
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
index 8df6617fe0..9e6bc7193c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
@@ -17,11 +17,15 @@
package org.apache.spark.ml.util
+import java.util.Random
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import StopwatchSuite._
+
private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
assert(sw.name === "sw")
assert(sw.elapsed() === 0L)
@@ -29,18 +33,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
intercept[AssertionError] {
sw.stop()
}
- sw.start()
- Thread.sleep(50)
- val duration = sw.stop()
- assert(duration >= 50 && duration < 100) // using a loose upper bound
+ val duration = checkStopwatch(sw)
val elapsed = sw.elapsed()
assert(elapsed === duration)
- sw.start()
- Thread.sleep(50)
- val duration2 = sw.stop()
- assert(duration2 >= 50 && duration2 < 100)
+ val duration2 = checkStopwatch(sw)
val elapsed2 = sw.elapsed()
assert(elapsed2 === duration + duration2)
+ assert(sw.toString === s"sw: ${elapsed2}ms")
sw.start()
assert(sw.isRunning)
intercept[AssertionError] {
@@ -61,14 +60,13 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
test("DistributedStopwatch on executors") {
val sw = new DistributedStopwatch(sc, "sw")
val rdd = sc.parallelize(0 until 4, 4)
+ val acc = sc.accumulator(0L)
rdd.foreach { i =>
- sw.start()
- Thread.sleep(50)
- sw.stop()
+ acc += checkStopwatch(sw)
}
assert(!sw.isRunning)
val elapsed = sw.elapsed()
- assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound
+ assert(elapsed === acc.value)
}
test("MultiStopwatch") {
@@ -81,29 +79,47 @@ class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
sw("some")
}
assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}")
- sw("local").start()
- sw("spark").start()
- Thread.sleep(50)
- sw("local").stop()
- Thread.sleep(50)
- sw("spark").stop()
+ val localDuration = checkStopwatch(sw("local"))
+ val sparkDuration = checkStopwatch(sw("spark"))
val localElapsed = sw("local").elapsed()
val sparkElapsed = sw("spark").elapsed()
- assert(localElapsed >= 50 && localElapsed < 100)
- assert(sparkElapsed >= 100 && sparkElapsed < 200)
+ assert(localElapsed === localDuration)
+ assert(sparkElapsed === sparkDuration)
assert(sw.toString ===
s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
val rdd = sc.parallelize(0 until 4, 4)
+ val acc = sc.accumulator(0L)
rdd.foreach { i =>
sw("local").start()
- sw("spark").start()
- Thread.sleep(50)
- sw("spark").stop()
+ val duration = checkStopwatch(sw("spark"))
sw("local").stop()
+ acc += duration
}
val localElapsed2 = sw("local").elapsed()
assert(localElapsed2 === localElapsed)
val sparkElapsed2 = sw("spark").elapsed()
- assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600)
+ assert(sparkElapsed2 === sparkElapsed + acc.value)
}
}
+
+private object StopwatchSuite extends SparkFunSuite {
+
+ /**
+ * Checks the input stopwatch on a task that takes a random time (<10ms) to finish. Validates and
+ * returns the duration reported by the stopwatch.
+ */
+ def checkStopwatch(sw: Stopwatch): Long = {
+ val ubStart = now
+ sw.start()
+ val lbStart = now
+ Thread.sleep(new Random().nextInt(10))
+ val lb = now - lbStart
+ val duration = sw.stop()
+ val ub = now - ubStart
+ assert(duration >= lb && duration <= ub)
+ duration
+ }
+
+ /** The current time in milliseconds. */
+ private def now: Long = System.currentTimeMillis()
+}