author     Joseph K. Bradley <joseph@databricks.com>    2015-07-30 07:56:15 -0700
committer  Xiangrui Meng <meng@databricks.com>          2015-07-30 07:56:15 -0700
commit     c5815930be46a89469440b7c61b59764fb67a54c (patch)
tree       a43746d0ea0824ae54d6cbb7eea305874ef02c52 /mllib/src/test
parent     d31c618e3c8838f8198556876b9dcbbbf835f7b2 (diff)
download   spark-c5815930be46a89469440b7c61b59764fb67a54c.tar.gz
           spark-c5815930be46a89469440b7c61b59764fb67a54c.tar.bz2
           spark-c5815930be46a89469440b7c61b59764fb67a54c.zip
[SPARK-5561] [MLLIB] Generalized PeriodicCheckpointer for RDDs and Graphs
PeriodicGraphCheckpointer was introduced for Latent Dirichlet Allocation (LDA), but it was meant to be generalized to work with Graphs, RDDs, and other data structures based on RDDs. This PR generalizes it.

For those who are not familiar with the periodic checkpointer: it automatically handles persisting/unpersisting and checkpointing/removing checkpoint files in a lineage of RDD-based objects. I need it generalized to use with GradientBoostedTrees [https://issues.apache.org/jira/browse/SPARK-6684]. It should be useful for other iterative algorithms as well.

Changes I made:
* Copied PeriodicGraphCheckpointer to PeriodicCheckpointer.
* Within PeriodicCheckpointer, created abstract methods for the basic operations (checkpoint, persist, etc.).
* Implemented those abstract methods in the subclasses for Graphs and RDDs.
* Copied the test suite for the graph checkpointer and made small modifications so it works for RDDs.

To review this PR, I recommend doing 2 diffs:
(1) between the old PeriodicGraphCheckpointer.scala and the new PeriodicCheckpointer.scala
(2) between the 2 test suites

CCing andrewor14 in case there are relevant changes to checkpointing. CCing feynmanliang in case you're interested in learning about checkpointing. CCing mengxr for final OK. Thanks all!

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7728 from jkbradley/gbt-checkpoint and squashes the following commits:

d41902c [Joseph K. Bradley] Fixed a missed update in the checkpointer tests after the last commit; also made some of the checkpointer methods protected.
32b23b8 [Joseph K. Bradley] Fixed usage of the checkpointer in LDA.
0b3dbc0 [Joseph K. Bradley] Changed checkpointer constructor not to take initial data.
568918c [Joseph K. Bradley] Generalized PeriodicGraphCheckpointer to PeriodicCheckpointer, with subclasses for RDDs and Graphs.
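For readers unfamiliar with the pattern, here is a minimal Scala sketch of the generalized checkpointer described above. The class shape is inferred from the PR description and the tests below; the name, the queue sizes, and the deletion policy are illustrative assumptions, not the actual implementation.

import scala.collection.mutable

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.SparkContext

// Sketch only: assumes the abstract-operations design described in this commit.
abstract class PeriodicCheckpointerSketch[T](
    checkpointInterval: Int,
    sc: SparkContext) {

  private var updateCount = 0
  private val persistedQueue = mutable.Queue.empty[T]   // recently persisted datasets
  private val checkpointQueue = mutable.Queue.empty[T]  // checkpointed datasets, oldest first

  // The basic operations, implemented by the subclasses for Graphs and RDDs.
  protected def checkpoint(data: T): Unit
  protected def isCheckpointed(data: T): Boolean
  protected def persist(data: T): Unit
  protected def unpersist(data: T): Unit
  protected def getCheckpointFiles(data: T): Iterable[String]

  // Called once per iteration with the newest dataset in the lineage.
  def update(newData: T): Unit = {
    persist(newData)
    persistedQueue.enqueue(newData)
    // Keep only a few recent datasets persisted; unpersist older ancestors.
    while (persistedQueue.size > 3) {
      unpersist(persistedQueue.dequeue())
    }
    updateCount += 1
    if (updateCount % checkpointInterval == 0 && sc.getCheckpointDir.nonEmpty) {
      checkpoint(newData)  // materializes lazily, on the next action over newData
      checkpointQueue.enqueue(newData)
      // Delete an older checkpoint only once a newer one has materialized;
      // because checkpointing is lazy, the tests below tolerate checkpoint
      // files surviving for up to two checkpoint intervals.
      var canDelete = true
      while (checkpointQueue.size > 1 && canDelete) {
        if (isCheckpointed(checkpointQueue(1))) {
          removeCheckpointFile(checkpointQueue.dequeue())
        } else {
          canDelete = false
        }
      }
    }
  }

  private def removeCheckpointFile(data: T): Unit = {
    val fs = FileSystem.get(sc.hadoopConfiguration)
    getCheckpointFiles(data).foreach(file => fs.delete(new Path(file), true))
  }
}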
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala |  16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala   | 173
2 files changed, 182 insertions(+), 7 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index d34888af2d..e331c75989 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -30,20 +30,20 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
import PeriodicGraphCheckpointerSuite._
- // TODO: Do I need to call count() on the graphs' RDDs?
-
test("Persisting") {
var graphsToCheck = Seq.empty[GraphToCheck]
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
+ val checkpointer =
+ new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
checkPersistence(graphsToCheck, 1)
var iteration = 2
while (iteration < 9) {
val graph = createGraph(sc)
- checkpointer.updateGraph(graph)
+ checkpointer.update(graph)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
checkPersistence(graphsToCheck, iteration)
iteration += 1
@@ -57,7 +57,9 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
var graphsToCheck = Seq.empty[GraphToCheck]
sc.setCheckpointDir(path)
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
+ val checkpointer = new PeriodicGraphCheckpointer[Double, Double](
+ checkpointInterval, graph1.vertices.sparkContext)
+ checkpointer.update(graph1)
graph1.edges.count()
graph1.vertices.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
@@ -66,7 +68,7 @@ class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkCo
var iteration = 2
while (iteration < 9) {
val graph = createGraph(sc)
- checkpointer.updateGraph(graph)
+ checkpointer.update(graph)
graph.vertices.count()
graph.edges.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
@@ -168,7 +170,7 @@ private object PeriodicGraphCheckpointerSuite {
} else {
// Graph should never be checkpointed
assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
- assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+ assert(graph.getCheckpointFiles.isEmpty, "Graph should not have any checkpoint files")
}
} catch {
case e: AssertionError =>
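
The hunks above capture the API change from commit 0b3dbc0: the constructor no longer takes the initial data, and updateGraph() becomes update(). A hedged before/after sketch of the calling convention follows; the graph construction is illustrative, and the example sits in the mllib.impl package because the checkpointer classes are package-private.

package org.apache.spark.mllib.impl

import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}

object CheckpointerApiChangeExample {
  def example(sc: SparkContext): Unit = {
    val edges = sc.parallelize(Seq(Edge(0L, 1L, 1.0), Edge(1L, 2L, 1.0)))
    val graph1: Graph[Double, Double] = Graph.fromEdges(edges, defaultValue = 0.0)

    // Old API (before this commit): the initial graph went to the constructor,
    // and later graphs were passed to updateGraph():
    //   val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
    //   checkpointer.updateGraph(graph2)

    // New API: construct with (checkpointInterval, SparkContext) and explicit
    // vertex/edge type parameters, then feed every graph through update():
    val checkpointer =
      new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext)
    checkpointer.update(graph1)
  }
}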
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
new file mode 100644
index 0000000000..b2a459a68b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.{SparkContext, SparkFunSuite}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ import PeriodicRDDCheckpointerSuite._
+
+ test("Persisting") {
+ var rddsToCheck = Seq.empty[RDDToCheck]
+
+ val rdd1 = createRDD(sc)
+ val checkpointer = new PeriodicRDDCheckpointer[Double](10, rdd1.sparkContext)
+ checkpointer.update(rdd1)
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+ checkPersistence(rddsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val rdd = createRDD(sc)
+ checkpointer.update(rdd)
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+ checkPersistence(rddsToCheck, iteration)
+ iteration += 1
+ }
+ }
+
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var rddsToCheck = Seq.empty[RDDToCheck]
+ sc.setCheckpointDir(path)
+ val rdd1 = createRDD(sc)
+ val checkpointer = new PeriodicRDDCheckpointer[Double](checkpointInterval, rdd1.sparkContext)
+ checkpointer.update(rdd1)
+ rdd1.count()
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd1, 1)
+ checkCheckpoint(rddsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val rdd = createRDD(sc)
+ checkpointer.update(rdd)
+ rdd.count()
+ rddsToCheck = rddsToCheck :+ RDDToCheck(rdd, iteration)
+ checkCheckpoint(rddsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
+
+ checkpointer.deleteAllCheckpoints()
+ rddsToCheck.foreach { rdd =>
+ confirmCheckpointRemoved(rdd.rdd)
+ }
+
+ Utils.deleteRecursively(tempDir)
+ }
+}
+
+private object PeriodicRDDCheckpointerSuite {
+
+ case class RDDToCheck(rdd: RDD[Double], gIndex: Int)
+
+ def createRDD(sc: SparkContext): RDD[Double] = {
+ sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
+ }
+
+ def checkPersistence(rdds: Seq[RDDToCheck], iteration: Int): Unit = {
+ rdds.foreach { g =>
+ checkPersistence(g.rdd, g.gIndex, iteration)
+ }
+ }
+
+ /**
+ * Check storage level of rdd.
+ * @param gIndex Index of rdd in order inserted into checkpointer (from 1).
+ * @param iteration Total number of rdds inserted into checkpointer.
+ */
+ def checkPersistence(rdd: RDD[_], gIndex: Int, iteration: Int): Unit = {
+ try {
+ if (gIndex + 2 < iteration) {
+ assert(rdd.getStorageLevel == StorageLevel.NONE)
+ } else {
+ assert(rdd.getStorageLevel != StorageLevel.NONE)
+ }
+ } catch {
+ case _: AssertionError =>
+ throw new Exception(s"PeriodicRDDCheckpointerSuite.checkPersistence failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t rdd.getStorageLevel = ${rdd.getStorageLevel}\n")
+ }
+ }
+
+ def checkCheckpoint(rdds: Seq[RDDToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+ rdds.reverse.foreach { g =>
+ checkCheckpoint(g.rdd, g.gIndex, iteration, checkpointInterval)
+ }
+ }
+
+ def confirmCheckpointRemoved(rdd: RDD[_]): Unit = {
+ // Note: We cannot check rdd.isCheckpointed since that value is never updated.
+ // Instead, we check for the presence of the checkpoint files.
+ // This test should continue to work even after this rdd.isCheckpointed issue
+ // is fixed (though it can then be simplified and not look for the files).
+ val fs = FileSystem.get(rdd.sparkContext.hadoopConfiguration)
+ rdd.getCheckpointFile.foreach { checkpointFile =>
+ assert(!fs.exists(new Path(checkpointFile)), "RDD checkpoint file should have been removed")
+ }
+ }
+
+ /**
+ * Check checkpointed status of rdd.
+ * @param gIndex Index of rdd in order inserted into checkpointer (from 1).
+ * @param iteration Total number of rdds inserted into checkpointer.
+ */
+ def checkCheckpoint(
+ rdd: RDD[_],
+ gIndex: Int,
+ iteration: Int,
+ checkpointInterval: Int): Unit = {
+ try {
+ if (gIndex % checkpointInterval == 0) {
+ // We allow 2 checkpoint intervals since we perform an action (checkpointing a second rdd)
+ // only AFTER PeriodicRDDCheckpointer decides whether to remove the previous checkpoint.
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+ assert(rdd.isCheckpointed, "RDD should be checkpointed")
+ assert(rdd.getCheckpointFile.nonEmpty, "RDD should have a checkpoint file")
+ } else {
+ confirmCheckpointRemoved(rdd)
+ }
+ } else {
+ // RDD should never be checkpointed
+ assert(!rdd.isCheckpointed, "RDD should never have been checkpointed")
+ assert(rdd.getCheckpointFile.isEmpty, "RDD should not have any checkpoint files")
+ }
+ } catch {
+ case e: AssertionError =>
+ throw new Exception(s"PeriodicRDDCheckpointerSuite.checkCheckpoint failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t checkpointInterval = $checkpointInterval\n" +
+ s"\t rdd.isCheckpointed = ${rdd.isCheckpointed}\n" +
+ s"\t rdd.getCheckpointFile = ${rdd.getCheckpointFile.mkString(", ")}\n" +
+ s" AssertionError message: ${e.getMessage}")
+ }
+ }
+
+}
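
Putting the new class to use: a hedged usage sketch mirroring the test suite above, in which an iterative algorithm hands each derived RDD to the checkpointer. The checkpoint directory and the map step are illustrative, and the example assumes access inside org.apache.spark.mllib since the class is package-private.

package org.apache.spark.mllib.impl

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object PeriodicRDDCheckpointerExample {
  def iterate(sc: SparkContext, numIterations: Int): Array[Double] = {
    sc.setCheckpointDir("/tmp/checkpoints")  // must be set before any checkpointing
    val checkpointer = new PeriodicRDDCheckpointer[Double](2, sc)

    var data: RDD[Double] = sc.parallelize(Seq(0.0, 1.0, 2.0, 3.0))
    checkpointer.update(data)
    data.count()  // an action materializes any pending checkpoint

    var iteration = 1
    while (iteration < numIterations) {
      data = data.map(_ + 1.0)  // next step in the lineage
      checkpointer.update(data)
      data.count()
      iteration += 1
    }

    val result = data.collect()          // extract the final result first...
    checkpointer.deleteAllCheckpoints()  // ...then remove all checkpoint files
    result
  }
}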