author    Xiangrui Meng <meng@databricks.com>    2015-02-05 15:07:33 -0800
committer Xiangrui Meng <meng@databricks.com>    2015-02-05 15:07:33 -0800
commit    c19152cd2a5d407ecf526a90e3bb059f09905b3a (patch)
tree      7b65a0a4e066eb6089f789cceb18f5603f6a0a77 /mllib/src
parent    62371adaa5b9251579db7300504506975689610c (diff)
[SPARK-5604][MLLIB] remove checkpointDir from LDA
`checkpointDir` is a Spark global configuration. Users should set it outside LDA. This PR also hides some methods under `private[clustering] object LDA`, so they don't show up in the generated Java doc (SPARK-5610).

jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #4390 from mengxr/SPARK-5604 and squashes the following commits:

a34bb39 [Xiangrui Meng] remove checkpointDir from LDA
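After this change, checkpointing is enabled by configuring the SparkContext once rather than through LDA. A minimal usage sketch (the checkpoint path and corpus are hypothetical placeholders; the setters are LDA's standard builder methods):

import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD

def trainWithCheckpoints(
    sc: SparkContext,
    corpus: RDD[(Long, Vector)]): DistributedLDAModel = {
  // Global setting, configured outside LDA; without it, checkpointInterval is ignored.
  sc.setCheckpointDir("hdfs:///tmp/lda-checkpoints")
  new LDA()
    .setK(10)
    .setMaxIterations(50)
    .setCheckpointInterval(10) // checkpoint the parameter graph every 10 iterations
    .run(corpus)
}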
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala                       | 73
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala       |  8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala  |  6
3 files changed, 23 insertions, 64 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index d8f82867a0..a1d3df03a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -52,6 +52,9 @@ import org.apache.spark.util.Utils
* - Paper which clearly explains several algorithms, including EM:
* Asuncion, Welling, Smyth, and Teh.
* "On Smoothing and Inference for Topic Models." UAI, 2009.
+ *
+ * @see [[http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation Latent Dirichlet allocation
+ * (Wikipedia)]]
*/
@Experimental
class LDA private (
@@ -60,11 +63,10 @@ class LDA private (
private var docConcentration: Double,
private var topicConcentration: Double,
private var seed: Long,
- private var checkpointDir: Option[String],
private var checkpointInterval: Int) extends Logging {
def this() = this(k = 10, maxIterations = 20, docConcentration = -1, topicConcentration = -1,
- seed = Utils.random.nextLong(), checkpointDir = None, checkpointInterval = 10)
+ seed = Utils.random.nextLong(), checkpointInterval = 10)
/**
* Number of topics to infer. I.e., the number of soft cluster centers.
@@ -201,49 +203,17 @@ class LDA private (
}
/**
- * Directory for storing checkpoint files during learning.
- * This is not necessary, but checkpointing helps with recovery (when nodes fail).
- * It also helps with eliminating temporary shuffle files on disk, which can be important when
- * LDA is run for many iterations.
- */
- def getCheckpointDir: Option[String] = checkpointDir
-
- /**
- * Directory for storing checkpoint files during learning.
- * This is not necessary, but checkpointing helps with recovery (when nodes fail).
- * It also helps with eliminating temporary shuffle files on disk, which can be important when
- * LDA is run for many iterations.
- *
- * NOTE: If the [[org.apache.spark.SparkContext.checkpointDir]] is already set, then the value
- * given to LDA is ignored, and the existing directory is kept.
- *
- * (default = None)
- */
- def setCheckpointDir(checkpointDir: String): this.type = {
- this.checkpointDir = Some(checkpointDir)
- this
- }
-
- /**
- * Clear the directory for storing checkpoint files during learning.
- * If one is already set in the [[org.apache.spark.SparkContext]], then checkpointing will still
- * occur; otherwise, no checkpointing will be used.
- */
- def clearCheckpointDir(): this.type = {
- this.checkpointDir = None
- this
- }
-
- /**
* Period (in iterations) between checkpoints.
- * @see [[getCheckpointDir]]
*/
def getCheckpointInterval: Int = checkpointInterval
/**
- * Period (in iterations) between checkpoints.
- * (default = 10)
- * @see [[getCheckpointDir]]
+ * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
+ * (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be
+ * important when LDA is run for many iterations. If the checkpoint directory is not set in
+ * [[org.apache.spark.SparkContext]], this setting is ignored.
+ *
+ * @see [[org.apache.spark.SparkContext#setCheckpointDir]]
*/
def setCheckpointInterval(checkpointInterval: Int): this.type = {
this.checkpointInterval = checkpointInterval
@@ -261,7 +231,7 @@ class LDA private (
*/
def run(documents: RDD[(Long, Vector)]): DistributedLDAModel = {
val state = LDA.initialState(documents, k, getDocConcentration, getTopicConcentration, seed,
- checkpointDir, checkpointInterval)
+ checkpointInterval)
var iter = 0
val iterationTimes = Array.fill[Double](maxIterations)(0)
while (iter < maxIterations) {
@@ -337,18 +307,18 @@ private[clustering] object LDA {
* Vector over topics (length k) of token counts.
* The meaning of these counts can vary, and it may or may not be normalized to be a distribution.
*/
- type TopicCounts = BDV[Double]
+ private[clustering] type TopicCounts = BDV[Double]
- type TokenCount = Double
+ private[clustering] type TokenCount = Double
/** Term vertex IDs are {-1, -2, ..., -vocabSize} */
- def term2index(term: Int): Long = -(1 + term.toLong)
+ private[clustering] def term2index(term: Int): Long = -(1 + term.toLong)
- def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
+ private[clustering] def index2term(termIndex: Long): Int = -(1 + termIndex).toInt
- def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
+ private[clustering] def isDocumentVertex(v: (VertexId, _)): Boolean = v._1 >= 0
- def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
+ private[clustering] def isTermVertex(v: (VertexId, _)): Boolean = v._1 < 0
/**
* Optimizer for EM algorithm which stores data + parameter graph, plus algorithm parameters.
@@ -360,17 +330,16 @@ private[clustering] object LDA {
* @param docConcentration "alpha"
* @param topicConcentration "beta" or "eta"
*/
- class EMOptimizer(
+ private[clustering] class EMOptimizer(
var graph: Graph[TopicCounts, TokenCount],
val k: Int,
val vocabSize: Int,
val docConcentration: Double,
val topicConcentration: Double,
- checkpointDir: Option[String],
checkpointInterval: Int) {
private[LDA] val graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
- graph, checkpointDir, checkpointInterval)
+ graph, checkpointInterval)
def next(): EMOptimizer = {
val eta = topicConcentration
@@ -468,7 +437,6 @@ private[clustering] object LDA {
docConcentration: Double,
topicConcentration: Double,
randomSeed: Long,
- checkpointDir: Option[String],
checkpointInterval: Int): EMOptimizer = {
// For each document, create an edge (Document -> Term) for each unique term in the document.
val edges: RDD[Edge[TokenCount]] = docs.flatMap { case (docID: Long, termCounts: Vector) =>
@@ -512,8 +480,7 @@ private[clustering] object LDA {
val graph = Graph(docVertices ++ termVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition1D)
- new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointDir,
- checkpointInterval)
+ new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
}
}
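For reference, the vertex-ID encoding made private above keeps document and term vertices disjoint in one Graph: documents use their non-negative IDs directly, while term j maps to -(j + 1), i.e. {-1, -2, ..., -vocabSize}, so the sign alone classifies a vertex. A standalone sketch that copies the two definitions from the diff and checks the round trip (the asserts are illustration only, not part of the patch):

def term2index(term: Int): Long = -(1 + term.toLong)
def index2term(termIndex: Long): Int = -(1 + termIndex).toInt

// The mapping round-trips exactly, and no term ID can collide with a document ID.
assert(term2index(0) == -1L)
assert(index2term(-1L) == 0)
assert((0 until 1000).forall(t => index2term(term2index(t)) == t))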
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
index 76672fe51e..6e5dd119dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala
@@ -74,7 +74,6 @@ import org.apache.spark.storage.StorageLevel
* }}}
*
* @param currentGraph Initial graph
- * @param checkpointDir The directory for storing checkpoint files
* @param checkpointInterval Graphs will be checkpointed at this interval
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
@@ -83,7 +82,6 @@ import org.apache.spark.storage.StorageLevel
*/
private[mllib] class PeriodicGraphCheckpointer[VD, ED](
var currentGraph: Graph[VD, ED],
- val checkpointDir: Option[String],
val checkpointInterval: Int) extends Logging {
/** FIFO queue of past checkpointed RDDs */
@@ -101,12 +99,6 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](
*/
private val sc = currentGraph.vertices.sparkContext
- // If a checkpoint directory is given, and there's no prior checkpoint directory,
- // then set the checkpoint directory with the given one.
- if (checkpointDir.nonEmpty && sc.getCheckpointDir.isEmpty) {
- sc.setCheckpointDir(checkpointDir.get)
- }
-
updateGraph(currentGraph)
/**
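With the constructor argument removed, PeriodicGraphCheckpointer leaves directory management entirely to its caller. A caller that wants to preserve the old "set only if unset" behavior could use a small helper along the lines of the removed guard (ensureCheckpointDir is a hypothetical name, not part of this patch):

import org.apache.spark.SparkContext

// Hypothetical caller-side helper mirroring the guard deleted above:
// set the global checkpoint directory only if none is configured yet.
def ensureCheckpointDir(sc: SparkContext, dir: String): Unit = {
  if (sc.getCheckpointDir.isEmpty) {
    sc.setCheckpointDir(dir)
  }
}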
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
index dac28a369b..699f009f0f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala
@@ -38,7 +38,7 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext
var graphsToCheck = Seq.empty[GraphToCheck]
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, None, 10)
+ val checkpointer = new PeriodicGraphCheckpointer(graph1, 10)
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
checkPersistence(graphsToCheck, 1)
@@ -57,9 +57,9 @@ class PeriodicGraphCheckpointerSuite extends FunSuite with MLlibTestSparkContext
val path = tempDir.toURI.toString
val checkpointInterval = 2
var graphsToCheck = Seq.empty[GraphToCheck]
-
+ sc.setCheckpointDir(path)
val graph1 = createGraph(sc)
- val checkpointer = new PeriodicGraphCheckpointer(graph1, Some(path), checkpointInterval)
+ val checkpointer = new PeriodicGraphCheckpointer(graph1, checkpointInterval)
graph1.edges.count()
graph1.vertices.count()
graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1)
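The test change above shows the pattern this patch standardizes on: configure the directory on the context, then checkpoint at a fixed interval. As a generic sketch of why the interval matters (plain RDDs for brevity; the map computation is a stand-in), periodic checkpoints truncate an ever-growing lineage, which aids recovery and lets temporary shuffle files be cleaned up:

import org.apache.spark.rdd.RDD

def iterate(init: RDD[Double], iters: Int, interval: Int): RDD[Double] = {
  var rdd = init
  for (i <- 1 to iters) {
    rdd = rdd.map(_ * 0.9).cache()
    if (i % interval == 0) {
      rdd.checkpoint() // requires sc.setCheckpointDir(...) to have been called
      rdd.count()      // materialize so the checkpoint is actually written
    }
  }
  rdd
}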