path: root/mllib/src/test
author     Joseph K. Bradley <joseph@databricks.com>    2015-02-02 23:57:35 -0800
committer  Xiangrui Meng <meng@databricks.com>          2015-02-02 23:57:37 -0800
commit     980764f3c0c065cc32454a036e8d0ead5a92037b (patch)
tree       916561cd9f7939191d663c2d5a6f097321da5bae /mllib/src/test
parent     0cc7b88c99405db99bc4c3d66f5409e5da0e3c6e (diff)
download   spark-980764f3c0c065cc32454a036e8d0ead5a92037b.tar.gz
           spark-980764f3c0c065cc32454a036e8d0ead5a92037b.tar.bz2
           spark-980764f3c0c065cc32454a036e8d0ead5a92037b.zip
[SPARK-1405] [mllib] Latent Dirichlet Allocation (LDA) using EM
**This PR introduces an API + simple implementation for Latent Dirichlet Allocation (LDA).**

The [design doc for this PR](https://docs.google.com/document/d/1kSsDqTeZMEB94Bs4GTd0mvdAmduvZSSkpoSfn-seAzo) has been updated since I initially posted it. In particular, see the API and Planning for the Future sections.

* Settle on a public API which may eventually include:
  * more inference algorithms
  * more options / functionality
* Have an initial easy-to-understand implementation which others may improve.
* This is NOT intended to support every topic model out there. However, if there are suggestions for making this extensible or pluggable in the future, that could be nice, as long as it does not complicate the API or implementation too much.
* This may not be very scalable currently. It will be important to check and improve accuracy. For correctness of the implementation, please check against the Asuncion et al. (2009) paper in the design doc.

**Dependency: This makes MLlib depend on GraphX.**

Files and classes:
* LDA.scala (441 lines):
  * class LDA (main estimator class)
  * LDA.Document (text + document ID)
* LDAModel.scala (266 lines):
  * abstract class LDAModel
  * class LocalLDAModel
  * class DistributedLDAModel
* LDAExample.scala (245 lines): script to run LDA + a simple (private) Tokenizer
* LDASuite.scala (144 lines)

Data/model representation and algorithm:
* Data/model: Uses GraphX, with term vertices + document vertices
* Algorithm: EM, following [Asuncion, Welling, Smyth, and Teh. "On Smoothing and Inference for Topic Models." UAI, 2009.](http://arxiv-web3.library.cornell.edu/abs/1205.2662v1)
* For more details, please see the description in the “DEVELOPERS NOTE” in LDA.scala

Please refer to the JIRA for more discussion + the [design doc for this PR](https://docs.google.com/document/d/1kSsDqTeZMEB94Bs4GTd0mvdAmduvZSSkpoSfn-seAzo).

Here, I list the main changes AFTER the design doc was posted.

Design decisions:
* logLikelihood() computes the log likelihood of the data and the current point estimate of parameters. This is different from the likelihood of the data given the hyperparameters, which would be harder to compute. I’d describe the current approach as more frequentist, whereas the harder approach would be more Bayesian.
* The current API takes Documents as token count vectors. I believe there should be an extended API taking RDD[String] or RDD[Array[String]] in a future PR. I have sketched this out in the design doc (as well as handier versions of getTopics returning Strings).
* Hyperparameters should be set differently for different inference/learning algorithms. See Asuncion et al. (2009) in the design doc for a good demonstration. I encourage good behavior via defaults and warning messages.

Items planned for future PRs:
* perplexity
* API taking Strings
* Should LDA be called LatentDirichletAllocation (and LDAModel be LatentDirichletAllocationModel)?
  * Pro: We may someday want LinearDiscriminantAnalysis.
  * Con: Very long names
* Should LDA reside in clustering? Or do we want a sub-package?
  * mllib.topicmodel
  * mllib.clustering.topicmodel
* Does the API seem reasonable and extensible?
* Unit tests:
  * Should there be a test which checks clustering results? E.g., train on a small, fake dataset with 2 very distinct topics/clusters, and ensure LDA finds those 2 topics/clusters. Does that sound useful or too flaky?

This has not been tested much for scaling. I have run it on a laptop for 200 iterations on a 5MB dataset with 1000 terms and 5 topics.
Running it for 500 iterations made it fail because of GC problems. I'm running larger-scale tests and will put results here, but future PRs may need to improve the scaling.

Thanks to:
* dlwh for the initial implementation
  * + jegonzal for some code in the initial implementation
* The many contributors towards topic model implementations in Spark which were referenced as a basis for this PR: akopich witgo yinxusen dlwh EntilZha jegonzal IlyaKozlov
  * Note: The plan is to include this full list in the authors if this PR gets merged. Please notify me if you prefer otherwise.

CC: mengxr

Authors:
* Joseph K. Bradley <joseph@databricks.com>
* Joseph Gonzalez <joseph.e.gonzalez@gmail.com>
* David Hall <david.lw.hall@gmail.com>
* Guoqiang Li <witgo@qq.com>
* Xiangrui Meng <meng@databricks.com>
* Pedro Rodriguez <pedro@snowgeek.org>
* Avanesov Valeriy <acopich@gmail.com>
* Xusen Yin <yinxusen@gmail.com>

Closes #2388
Closes #4047 from jkbradley/davidhall-lda and squashes the following commits:

77e8814 [Joseph K. Bradley] small doc fix
5c74345 [Joseph K. Bradley] cleaned up doc based on code review
589728b [Joseph K. Bradley] Updates per code review. Main change was in LDAExample for faster vocab computation. Also updated PeriodicGraphCheckpointerSuite.scala to clean up checkpoint files at end
e3980d2 [Joseph K. Bradley] cleaned up PeriodicGraphCheckpointerSuite.scala
74487e5 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into davidhall-lda
4ae2a7d [Joseph K. Bradley] removed duplicate graphx dependency in mllib/pom.xml
e391474 [Joseph K. Bradley] Removed LDATiming. Added PeriodicGraphCheckpointerSuite.scala. Small LDA cleanups.
e8d8acf [Joseph K. Bradley] Added catch for BreakIterator exception. Improved preprocessing to reduce passes over data
1a231b4 [Joseph K. Bradley] fixed scalastyle
91aadfe [Joseph K. Bradley] Added Java-friendly run method to LDA. Added Java test suite for LDA. Changed LDAModel.describeTopics to return Java-friendly type
b75472d [Joseph K. Bradley] merged improvements from LDATiming into LDAExample. Will remove LDATiming after done testing
993ca56 [Joseph K. Bradley] * Removed Document type in favor of (Long, Vector) * Changed doc ID restriction to be: id must be nonnegative and unique in the doc (instead of 0,1,2,...) * Add checks for valid ranges of eta, alpha * Rename “LearningState” to “EMOptimizer” * Renamed params: termSmoothing -> topicConcentration, topicSmoothing -> docConcentration * Also added aliases alpha, beta
cb5a319 [Joseph K. Bradley] Added checkpointing to LDA * new class PeriodicGraphCheckpointer * params checkpointDir, checkpointInterval to LDA
43c1c40 [Joseph K. Bradley] small cleanup
0b90393 [Joseph K. Bradley] renamed LDA LearningState.collectTopicTotals to globalTopicTotals
77a2c85 [Joseph K. Bradley] Moved auto term,topic smoothing computation to get*Smoothing methods. Changed word to term in some places. Updated LDAExample to use default smoothing amounts.
fb1e7b5 [Xiangrui Meng] minor
08d59a3 [Xiangrui Meng] reset spacing
9fe0b95 [Xiangrui Meng] optimize aggregateMessages
cec0a9c [Xiangrui Meng] * -> *=
6cb11b0 [Xiangrui Meng] optimize computePTopic
9eb3d02 [Xiangrui Meng] + -> +=
892530c [Xiangrui Meng] use axpy
45cc7f2 [Xiangrui Meng] mapPart -> flatMap
ce53be9 [Joseph K. Bradley] fixed example name
75749e7 [Joseph K. Bradley] scala style fix
9f2a492 [Joseph K. Bradley] Unit tests and fixes for LDA, now ready for PR
377ebd9 [Joseph K. Bradley] separated LDA models into own file. more cleanups before PR
2d40006 [Joseph K. Bradley] cleanups before PR
2891e89 [Joseph K. Bradley] Prepped LDA main class for PR, but some cleanups remain
0cb7187 [Joseph K. Bradley] Added 3 files from dlwh LDA implementation
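
For orientation (not part of the commit itself), here is a minimal sketch of how the new API is exercised, pieced together from the LDASuite/JavaLDASuite tests added below. The object name, the toy term counts, and the local SparkContext setup are illustrative assumptions; the setter and accessor names come from the test code in this diff.

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vectors

object LDAUsageSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LDAUsageSketch").setMaster("local"))

    // A corpus is an RDD of (document ID, term-count vector) pairs; per commit 993ca56,
    // IDs must be nonnegative and unique. The counts here are toy values.
    val corpus = sc.parallelize(Seq(
      Vectors.dense(1, 3, 0, 2, 8),
      Vectors.dense(0, 2, 1, 0, 4),
      Vectors.dense(2, 3, 12, 3, 1)
    ).zipWithIndex.map { case (counts, docId) => (docId.toLong, counts) })

    // Configure and run EM-based LDA; docConcentration/topicConcentration are the
    // renamed topicSmoothing/termSmoothing parameters (aliases alpha, beta).
    val lda = new LDA()
      .setK(3)
      .setDocConcentration(1.2)
      .setTopicConcentration(1.2)
      .setMaxIterations(10)
      .setSeed(12345)
    val model: DistributedLDAModel = lda.run(corpus)

    // Inspect the result: per-topic (term indices, term weights), plus the
    // log probabilities that the new tests check are negative.
    model.describeTopics(maxTermsPerTopic = 3).zipWithIndex.foreach { case ((terms, weights), i) =>
      println(s"Topic $i: " + terms.zip(weights).mkString(", "))
    }
    println(s"logLikelihood = ${model.logLikelihood}, logPrior = ${model.logPrior}")

    sc.stop()
  }
}
```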
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java                 119
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala                   153
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala   187
3 files changed, 459 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
new file mode 100644
index 0000000000..dc10aa67c7
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaLDASuite.java
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.apache.spark.api.java.JavaRDD;
+import scala.Tuple2;
+
+import org.junit.After;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertArrayEquals;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.Vector;
+
+
+public class JavaLDASuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLDA");
+ ArrayList<Tuple2<Long, Vector>> tinyCorpus = new ArrayList<Tuple2<Long, Vector>>();
+ for (int i = 0; i < LDASuite$.MODULE$.tinyCorpus().length; i++) {
+ tinyCorpus.add(new Tuple2<Long, Vector>((Long)LDASuite$.MODULE$.tinyCorpus()[i]._1(),
+ LDASuite$.MODULE$.tinyCorpus()[i]._2()));
+ }
+ JavaRDD<Tuple2<Long, Vector>> tmpCorpus = sc.parallelize(tinyCorpus, 2);
+ corpus = JavaPairRDD.fromJavaRDD(tmpCorpus);
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void localLDAModel() {
+ LocalLDAModel model = new LocalLDAModel(LDASuite$.MODULE$.tinyTopics());
+
+ // Check: basic parameters
+ assertEquals(model.k(), tinyK);
+ assertEquals(model.vocabSize(), tinyVocabSize);
+ assertEquals(model.topicsMatrix(), tinyTopics);
+
+ // Check: describeTopics() with all terms
+ Tuple2<int[], double[]>[] fullTopicSummary = model.describeTopics();
+ assertEquals(fullTopicSummary.length, tinyK);
+ for (int i = 0; i < fullTopicSummary.length; i++) {
+ assertArrayEquals(fullTopicSummary[i]._1(), tinyTopicDescription[i]._1());
+ assertArrayEquals(fullTopicSummary[i]._2(), tinyTopicDescription[i]._2(), 1e-5);
+ }
+ }
+
+ @Test
+ public void distributedLDAModel() {
+ int k = 3;
+ double topicSmoothing = 1.2;
+ double termSmoothing = 1.2;
+
+ // Train a model
+ LDA lda = new LDA();
+ lda.setK(k)
+ .setDocConcentration(topicSmoothing)
+ .setTopicConcentration(termSmoothing)
+ .setMaxIterations(5)
+ .setSeed(12345);
+
+ DistributedLDAModel model = lda.run(corpus);
+
+ // Check: basic parameters
+ LocalLDAModel localModel = model.toLocal();
+ assertEquals(model.k(), k);
+ assertEquals(localModel.k(), k);
+ assertEquals(model.vocabSize(), tinyVocabSize);
+ assertEquals(localModel.vocabSize(), tinyVocabSize);
+ assertEquals(model.topicsMatrix(), localModel.topicsMatrix());
+
+ // Check: topic summaries
+ Tuple2<int[], double[]>[] roundedTopicSummary = model.describeTopics();
+ assertEquals(roundedTopicSummary.length, k);
+ Tuple2<int[], double[]>[] roundedLocalTopicSummary = localModel.describeTopics();
+ assertEquals(roundedLocalTopicSummary.length, k);
+
+ // Check: log probabilities
+ assert(model.logLikelihood() < 0.0);
+ assert(model.logPrior() < 0.0);
+ }
+
+ private static int tinyK = LDASuite$.MODULE$.tinyK();
+ private static int tinyVocabSize = LDASuite$.MODULE$.tinyVocabSize();
+ private static Matrix tinyTopics = LDASuite$.MODULE$.tinyTopics();
+ private static Tuple2<int[], double[]>[] tinyTopicDescription =
+ LDASuite$.MODULE$.tinyTopicDescription();
+ JavaPairRDD<Long, Vector> corpus;
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
new file mode 100644
index 0000000000..302d751eb8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class LDASuite extends FunSuite with MLlibTestSparkContext {
+
+ import LDASuite._
+
+ test("LocalLDAModel") {
+ val model = new LocalLDAModel(tinyTopics)
+
+ // Check: basic parameters
+ assert(model.k === tinyK)
+ assert(model.vocabSize === tinyVocabSize)
+ assert(model.topicsMatrix === tinyTopics)
+
+ // Check: describeTopics() with all terms
+ val fullTopicSummary = model.describeTopics()
+ assert(fullTopicSummary.size === tinyK)
+ fullTopicSummary.zip(tinyTopicDescription).foreach {
+ case ((algTerms, algTermWeights), (terms, termWeights)) =>
+ assert(algTerms === terms)
+ assert(algTermWeights === termWeights)
+ }
+
+ // Check: describeTopics() with some terms
+ val smallNumTerms = 3
+ val smallTopicSummary = model.describeTopics(maxTermsPerTopic = smallNumTerms)
+ smallTopicSummary.zip(tinyTopicDescription).foreach {
+ case ((algTerms, algTermWeights), (terms, termWeights)) =>
+ assert(algTerms === terms.slice(0, smallNumTerms))
+ assert(algTermWeights === termWeights.slice(0, smallNumTerms))
+ }
+ }
+
+ test("running and DistributedLDAModel") {
+ val k = 3
+ val topicSmoothing = 1.2
+ val termSmoothing = 1.2
+
+ // Train a model
+ val lda = new LDA()
+ lda.setK(k)
+ .setDocConcentration(topicSmoothing)
+ .setTopicConcentration(termSmoothing)
+ .setMaxIterations(5)
+ .setSeed(12345)
+ val corpus = sc.parallelize(tinyCorpus, 2)
+
+ val model: DistributedLDAModel = lda.run(corpus)
+
+ // Check: basic parameters
+ val localModel = model.toLocal
+ assert(model.k === k)
+ assert(localModel.k === k)
+ assert(model.vocabSize === tinyVocabSize)
+ assert(localModel.vocabSize === tinyVocabSize)
+ assert(model.topicsMatrix === localModel.topicsMatrix)
+
+ // Check: topic summaries
+ // The odd decimal formatting and sorting is a hack to do a robust comparison.
+ val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
+ // cut values to 3 digits after the decimal place
+ terms.zip(termWeights).map { case (term, weight) =>
+ ("%.3f".format(weight).toDouble, term.toInt)
+ }
+ }.sortBy(_.mkString(""))
+ val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
+ // cut values to 3 digits after the decimal place
+ terms.zip(termWeights).map { case (term, weight) =>
+ ("%.3f".format(weight).toDouble, term.toInt)
+ }
+ }.sortBy(_.mkString(""))
+ roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
+ assert(t1 === t2)
+ }
+
+ // Check: per-doc topic distributions
+ val topicDistributions = model.topicDistributions.collect()
+ // Ensure all documents are covered.
+ assert(topicDistributions.size === tinyCorpus.size)
+ assert(tinyCorpus.map(_._1).toSet === topicDistributions.map(_._1).toSet)
+ // Ensure we have proper distributions
+ topicDistributions.foreach { case (docId, topicDistribution) =>
+ assert(topicDistribution.size === tinyK)
+ assert(topicDistribution.toArray.sum ~== 1.0 absTol 1e-5)
+ }
+
+ // Check: log probabilities
+ assert(model.logLikelihood < 0.0)
+ assert(model.logPrior < 0.0)
+ }
+
+ test("vertex indexing") {
+ // Check vertex ID indexing and conversions.
+ val docIds = Array(0, 1, 2)
+ val docVertexIds = docIds
+ val termIds = Array(0, 1, 2)
+ val termVertexIds = Array(-1, -2, -3)
+ assert(docVertexIds.forall(i => !LDA.isTermVertex((i.toLong, 0))))
+ assert(termIds.map(LDA.term2index) === termVertexIds)
+ assert(termVertexIds.map(i => LDA.index2term(i.toLong)) === termIds)
+ assert(termVertexIds.forall(i => LDA.isTermVertex((i.toLong, 0))))
+ }
+}
+
+private[clustering] object LDASuite {
+
+ def tinyK: Int = 3
+ def tinyVocabSize: Int = 5
+ def tinyTopicsAsArray: Array[Array[Double]] = Array(
+ Array[Double](0.1, 0.2, 0.3, 0.4, 0.0), // topic 0
+ Array[Double](0.5, 0.05, 0.05, 0.1, 0.3), // topic 1
+ Array[Double](0.2, 0.2, 0.05, 0.05, 0.5) // topic 2
+ )
+ def tinyTopics: Matrix = new DenseMatrix(numRows = tinyVocabSize, numCols = tinyK,
+ values = tinyTopicsAsArray.fold(Array.empty[Double])(_ ++ _))
+ def tinyTopicDescription: Array[(Array[Int], Array[Double])] = tinyTopicsAsArray.map { topic =>
+ val (termWeights, terms) = topic.zipWithIndex.sortBy(-_._1).unzip
+ (terms.toArray, termWeights.toArray)
+ }
+
+ def tinyCorpus = Array(
+ Vectors.dense(1, 3, 0, 2, 8),
+ Vectors.dense(0, 2, 1, 0, 4),
+ Vectors.dense(2, 3, 12, 3, 1),
+ Vectors.dense(0, 3, 1, 9, 8),
+ Vectors.dense(1, 1, 4, 2, 6)
+ ).zipWithIndex.map { case (wordCounts, docId) => (docId.toLong, wordCounts) }
+ assert(tinyCorpus.forall(_._2.size == tinyVocabSize)) // sanity check for test data
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
new file mode 100644
index 0000000000..dac28a369b
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.impl
+
+import org.scalatest.FunSuite
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.graphx.{Edge, Graph}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+
+
+class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext {
+
+ import PeriodicGraphCheckpointerSuite._
+
+ // TODO: Do I need to call count() on the graphs' RDDs?
+
+ test("Persisting") {
+ var graphsToCheck = Seq.empty[GraphToCheck]
+
+ val graph1 = createGraph(sc)
+ val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkPersistence(graphsToCheck, 1)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.updateGraph(graph)
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkPersistence(graphsToCheck, iteration)
+ iteration += 1
+ }
+ }
+
+ test("Checkpointing") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ val checkpointInterval = 2
+ var graphsToCheck = Seq.empty[GraphToCheck]
+
+ val graph1 = createGraph(sc)
+ val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval)
+ graph1.edges.count()
+ graph1.vertices.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
+ checkCheckpoint(graphsToCheck, 1, checkpointInterval)
+
+ var iteration = 2
+ while (iteration < 9) {
+ val graph = createGraph(sc)
+ checkpointer.updateGraph(graph)
+ graph.vertices.count()
+ graph.edges.count()
+ graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration)
+ checkCheckpoint(graphsToCheck, iteration, checkpointInterval)
+ iteration += 1
+ }
+
+ checkpointer.deleteAllCheckpoints()
+ graphsToCheck.foreach { graph =>
+ confirmCheckpointRemoved(graph.graph)
+ }
+
+ Utils.deleteRecursively(tempDir)
+ }
+}
+
+private object PeriodicGraphCheckpointerSuite {
+
+ case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int)
+
+ val edges = Seq(
+ Edge[Double](0, 1, 0),
+ Edge[Double](1, 2, 0),
+ Edge[Double](2, 3, 0),
+ Edge[Double](3, 4, 0))
+
+ def createGraph(sc: SparkContext): Graph[Double, Double] = {
+ Graph.fromEdges[Double, Double](sc.parallelize(edges), 0)
+ }
+
+ def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = {
+ graphs.foreach { g =>
+ checkPersistence(g.graph, g.gIndex, iteration)
+ }
+ }
+
+ /**
+ * Check storage level of graph.
+ * @param gIndex Index of graph in order inserted into checkpointer (from 1).
+ * @param iteration Total number of graphs inserted into checkpointer.
+ */
+ def checkPersistence(graph: Graph[_, _], gIndex: Int, iteration: Int): Unit = {
+ try {
+ if (gIndex + 2 < iteration) {
+ assert(graph.vertices.getStorageLevel == StorageLevel.NONE)
+ assert(graph.edges.getStorageLevel == StorageLevel.NONE)
+ } else {
+ assert(graph.vertices.getStorageLevel != StorageLevel.NONE)
+ assert(graph.edges.getStorageLevel != StorageLevel.NONE)
+ }
+ } catch {
+ case _: AssertionError =>
+ throw new Exception(s"PeriodicGraphCheckpointerSuite.checkPersistence failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t graph.vertices.getStorageLevel = ${graph.vertices.getStorageLevel}\n" +
+ s"\t graph.edges.getStorageLevel = ${graph.edges.getStorageLevel}\n")
+ }
+ }
+
+ def checkCheckpoint(graphs: Seq[GraphToCheck], iteration: Int, checkpointInterval: Int): Unit = {
+ graphs.reverse.foreach { g =>
+ checkCheckpoint(g.graph, g.gIndex, iteration, checkpointInterval)
+ }
+ }
+
+ def confirmCheckpointRemoved(graph: Graph[_, _]): Unit = {
+ // Note: We cannot check graph.isCheckpointed since that value is never updated.
+ // Instead, we check for the presence of the checkpoint files.
+ // This test should continue to work even after this graph.isCheckpointed issue
+ // is fixed (though it can then be simplified and not look for the files).
+ val fs = FileSystem.get(graph.vertices.sparkContext.hadoopConfiguration)
+ graph.getCheckpointFiles.foreach { checkpointFile =>
+ assert(!fs.exists(new Path(checkpointFile)),
+ "Graph checkpoint file should have been removed")
+ }
+ }
+
+ /**
+ * Check checkpointed status of graph.
+ * @param gIndex Index of graph in order inserted into checkpointer (from 1).
+ * @param iteration Total number of graphs inserted into checkpointer.
+ */
+ def checkCheckpoint(
+ graph: Graph[_, _],
+ gIndex: Int,
+ iteration: Int,
+ checkpointInterval: Int): Unit = {
+ try {
+ if (gIndex % checkpointInterval == 0) {
+ // We allow 2 checkpoint intervals since we perform an action (checkpointing a second graph)
+ // only AFTER PeriodicGraphCheckpointer decides whether to remove the previous checkpoint.
+ if (iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration) {
+ assert(graph.isCheckpointed, "Graph should be checkpointed")
+ assert(graph.getCheckpointFiles.length == 2, "Graph should have 2 checkpoint files")
+ } else {
+ confirmCheckpointRemoved(graph)
+ }
+ } else {
+ // Graph should never be checkpointed
+ assert(!graph.isCheckpointed, "Graph should never have been checkpointed")
+ assert(graph.getCheckpointFiles.length == 0, "Graph should not have any checkpoint files")
+ }
+ } catch {
+ case e: AssertionError =>
+ throw new Exception(s"PeriodicGraphCheckpointerSuite.checkCheckpoint failed with:\n" +
+ s"\t gIndex = $gIndex\n" +
+ s"\t iteration = $iteration\n" +
+ s"\t checkpointInterval = $checkpointInterval\n" +
+ s"\t graph.isCheckpointed = ${graph.isCheckpointed}\n" +
+ s"\t graph.getCheckpointFiles = ${graph.getCheckpointFiles.mkString(", ")}\n" +
+ s" AssertionError message: ${e.getMessage}")
+ }
+ }
+
+}
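
As an aside (not part of the commit), the retention behaviour that checkPersistence and checkCheckpoint assert can be summarized with a small worked example. The sketch below simply restates the predicates from those helpers and prints them for one scenario; the iteration count and checkpointInterval values mirror the suite, and the object name is illustrative.

```scala
// Restates the predicates asserted by PeriodicGraphCheckpointerSuite for one
// worked scenario (illustrative only).
object CheckpointerRetentionSketch {
  def main(args: Array[String]): Unit = {
    val iteration = 8           // graphs 1..8 have been handed to the checkpointer
    val checkpointInterval = 2  // matches the "Checkpointing" test above

    (1 to iteration).foreach { gIndex =>
      // checkPersistence: only the most recently inserted graphs stay cached.
      val stillPersisted = gIndex + 2 >= iteration
      // checkCheckpoint: graphs are checkpointed only at multiples of the interval,
      // and their files survive only within the last two checkpoint intervals.
      val wasCheckpointed = gIndex % checkpointInterval == 0
      val filesKept = wasCheckpointed &&
        iteration - 2 * checkpointInterval < gIndex && gIndex <= iteration
      println(s"graph $gIndex: persisted=$stillPersisted, " +
        s"checkpointed=$wasCheckpointed, checkpointFilesKept=$filesKept")
    }
  }
}
```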