author    Xiangrui Meng <meng@databricks.com>  2015-07-15 21:02:42 -0700
committer Xiangrui Meng <meng@databricks.com>  2015-07-15 21:02:42 -0700
commit    73d92b00b9a6f5dfc2f8116447d17b381cd74f80 (patch)
tree      aa88d4394b9fc2176a7266236a07dd5acceda489
parent    6960a7938c61cc07f181ca85e0d8152ceeb453d9 (diff)
download  spark-73d92b00b9a6f5dfc2f8116447d17b381cd74f80.tar.gz
          spark-73d92b00b9a6f5dfc2f8116447d17b381cd74f80.tar.bz2
          spark-73d92b00b9a6f5dfc2f8116447d17b381cd74f80.zip
[SPARK-9018] [MLLIB] add stopwatches
Add stopwatches for easy instrumentation of MLlib algorithms. This is based on the `TimeTracker` used in decision trees. The distributed version uses a Spark accumulator. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #7415 from mengxr/SPARK-9018 and squashes the following commits:

40b4347 [Xiangrui Meng] == -> ===
c477745 [Xiangrui Meng] address Joseph's comments
f981a49 [Xiangrui Meng] add stopwatches
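To make the intended usage concrete, here is a minimal sketch of the API this patch adds (hypothetical driver-side code; `sc` is an existing SparkContext and the stopwatch names are illustrative):

    import org.apache.spark.ml.util.{DistributedStopwatch, LocalStopwatch}

    // Time driver-side work with a local stopwatch.
    val local = new LocalStopwatch("driver work")
    local.start()
    // ... driver-side computation ...
    local.stop()                  // returns this session's duration in ms
    println(local.elapsed())      // total across all sessions so far

    // Time work inside RDD closures with the accumulator-backed version.
    val dist = new DistributedStopwatch(sc, "executor work")
    sc.parallelize(0 until 4, 4).foreach { _ =>
      dist.start()
      // ... per-task computation ...
      dist.stop()                 // folds the duration into the accumulator
    }
    println(dist.elapsed())       // driver sees the sum over all tasks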
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala    151
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala  109
2 files changed, 260 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
new file mode 100644
index 0000000000..5fdf878a3d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/stopwatches.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import scala.collection.mutable
+
+import org.apache.spark.{Accumulator, SparkContext}
+
+/**
+ * Abstract class for stopwatches.
+ */
+private[spark] abstract class Stopwatch extends Serializable {
+
+ @transient private var running: Boolean = false
+ private var startTime: Long = _
+
+ /**
+ * Name of the stopwatch.
+ */
+ val name: String
+
+ /**
+ * Starts the stopwatch.
+ * Throws an exception if the stopwatch is already running.
+ */
+ def start(): Unit = {
+ assume(!running, "start() called but the stopwatch is already running.")
+ running = true
+ startTime = now
+ }
+
+ /**
+ * Stops the stopwatch and returns the duration of the last session in milliseconds.
+ * Throws an exception if the stopwatch is not running.
+ */
+ def stop(): Long = {
+ assume(running, "stop() called but the stopwatch is not running.")
+ val duration = now - startTime
+ add(duration)
+ running = false
+ duration
+ }
+
+ /**
+ * Checks whether the stopwatch is running.
+ */
+ def isRunning: Boolean = running
+
+ /**
+ * Returns total elapsed time in milliseconds, not counting the current session if the stopwatch
+ * is running.
+ */
+ def elapsed(): Long
+
+ /**
+ * Gets the current time in milliseconds.
+ */
+ protected def now: Long = System.currentTimeMillis()
+
+ /**
+ * Adds input duration to total elapsed time.
+ */
+ protected def add(duration: Long): Unit
+}
+
+/**
+ * A local [[Stopwatch]].
+ */
+private[spark] class LocalStopwatch(override val name: String) extends Stopwatch {
+
+ private var elapsedTime: Long = 0L
+
+ override def elapsed(): Long = elapsedTime
+
+ override protected def add(duration: Long): Unit = {
+ elapsedTime += duration
+ }
+}
+
+/**
+ * A distributed [[Stopwatch]] using a Spark accumulator.
+ * @param sc SparkContext
+ */
+private[spark] class DistributedStopwatch(
+ sc: SparkContext,
+ override val name: String) extends Stopwatch {
+
+ private val elapsedTime: Accumulator[Long] = sc.accumulator(0L, s"DistributedStopwatch($name)")
+
+ override def elapsed(): Long = elapsedTime.value
+
+ override protected def add(duration: Long): Unit = {
+ elapsedTime += duration
+ }
+}
+
+/**
+ * A container for multiple stopwatches, which can be local or distributed.
+ * @param sc SparkContext
+ */
+private[spark] class MultiStopwatch(@transient private val sc: SparkContext) extends Serializable {
+
+ private val stopwatches: mutable.Map[String, Stopwatch] = mutable.Map.empty
+
+ /**
+ * Adds a local stopwatch.
+ * @param name stopwatch name
+ */
+ def addLocal(name: String): this.type = {
+ require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+ stopwatches(name) = new LocalStopwatch(name)
+ this
+ }
+
+ /**
+ * Adds a distributed stopwatch.
+ * @param name stopwatch name
+ */
+ def addDistributed(name: String): this.type = {
+ require(!stopwatches.contains(name), s"Stopwatch with name $name already exists.")
+ stopwatches(name) = new DistributedStopwatch(sc, name)
+ this
+ }
+
+ /**
+ * Gets a stopwatch.
+ * @param name stopwatch name
+ */
+ def apply(name: String): Stopwatch = stopwatches(name)
+
+ override def toString: String = {
+ stopwatches.values.toArray.sortBy(_.name)
+ .map(c => s" ${c.name}: ${c.elapsed()}ms")
+ .mkString("{\n", ",\n", "\n}")
+ }
+}
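A note on the semantics above, with a hypothetical snippet to illustrate: the session state (`running`, `startTime`) lives in whichever JVM calls `start()`, and the `@transient` flag ensures a copy deserialized on an executor starts out not running. Only accumulated durations flow back to the driver through the accumulator, so `elapsed()` on the driver reflects the sum over all tasks, while `isRunning` only describes the local copy.

    // Hypothetical illustration; `sc` is an existing SparkContext.
    val sw = new DistributedStopwatch(sc, "task work")
    sc.parallelize(0 until 2, 2).foreach { _ =>
      sw.start()        // starts this task's deserialized copy
      Thread.sleep(10)
      sw.stop()         // adds ~10ms to the shared accumulator
    }
    assert(!sw.isRunning)       // the driver's copy was never started
    assert(sw.elapsed() >= 20)  // driver sees the sum of task durations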
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
new file mode 100644
index 0000000000..8df6617fe0
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/StopwatchSuite.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class StopwatchSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ private def testStopwatchOnDriver(sw: Stopwatch): Unit = {
+ assert(sw.name === "sw")
+ assert(sw.elapsed() === 0L)
+ assert(!sw.isRunning)
+ intercept[AssertionError] {
+ sw.stop()
+ }
+ sw.start()
+ Thread.sleep(50)
+ val duration = sw.stop()
+ assert(duration >= 50 && duration < 100) // using a loose upper bound
+ val elapsed = sw.elapsed()
+ assert(elapsed === duration)
+ sw.start()
+ Thread.sleep(50)
+ val duration2 = sw.stop()
+ assert(duration2 >= 50 && duration2 < 100)
+ val elapsed2 = sw.elapsed()
+ assert(elapsed2 === duration + duration2)
+ sw.start()
+ assert(sw.isRunning)
+ intercept[AssertionError] {
+ sw.start()
+ }
+ }
+
+ test("LocalStopwatch") {
+ val sw = new LocalStopwatch("sw")
+ testStopwatchOnDriver(sw)
+ }
+
+ test("DistributedStopwatch on driver") {
+ val sw = new DistributedStopwatch(sc, "sw")
+ testStopwatchOnDriver(sw)
+ }
+
+ test("DistributedStopwatch on executors") {
+ val sw = new DistributedStopwatch(sc, "sw")
+ val rdd = sc.parallelize(0 until 4, 4)
+ rdd.foreach { i =>
+ sw.start()
+ Thread.sleep(50)
+ sw.stop()
+ }
+ assert(!sw.isRunning)
+ val elapsed = sw.elapsed()
+ assert(elapsed >= 200 && elapsed < 400) // using a loose upper bound
+ }
+
+ test("MultiStopwatch") {
+ val sw = new MultiStopwatch(sc)
+ .addLocal("local")
+ .addDistributed("spark")
+ assert(sw("local").name === "local")
+ assert(sw("spark").name === "spark")
+ intercept[NoSuchElementException] {
+ sw("some")
+ }
+ assert(sw.toString === "{\n local: 0ms,\n spark: 0ms\n}")
+ sw("local").start()
+ sw("spark").start()
+ Thread.sleep(50)
+ sw("local").stop()
+ Thread.sleep(50)
+ sw("spark").stop()
+ val localElapsed = sw("local").elapsed()
+ val sparkElapsed = sw("spark").elapsed()
+ assert(localElapsed >= 50 && localElapsed < 100)
+ assert(sparkElapsed >= 100 && sparkElapsed < 200)
+ assert(sw.toString ===
+ s"{\n local: ${localElapsed}ms,\n spark: ${sparkElapsed}ms\n}")
+ val rdd = sc.parallelize(0 until 4, 4)
+ rdd.foreach { i =>
+ sw("local").start()
+ sw("spark").start()
+ Thread.sleep(50)
+ sw("spark").stop()
+ sw("local").stop()
+ }
+ val localElapsed2 = sw("local").elapsed()
+ assert(localElapsed2 === localElapsed)
+ val sparkElapsed2 = sw("spark").elapsed()
+ assert(sparkElapsed2 >= 300 && sparkElapsed2 < 600)
+ }
+}