-rw-r--r--  docs/mllib-clustering.md                                                10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala  228
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala   41
3 files changed, 274 insertions(+), 5 deletions(-)
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 0fc7036bff..bb875ae2ae 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -472,7 +472,7 @@ to the algorithm. We then output the topics, represented as probability distribu
<div data-lang="scala" markdown="1">
{% highlight scala %}
-import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.clustering.{LDA, DistributedLDAModel}
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -492,6 +492,11 @@ for (topic <- Range(0, 3)) {
for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); }
println()
}
+
+// Save and load model.
+ldaModel.save(sc, "myLDAModel")
+val sameModel = DistributedLDAModel.load(sc, "myLDAModel")
+
{% endhighlight %}
</div>
@@ -551,6 +556,9 @@ public class JavaLDAExample {
}
System.out.println();
}
+
+ ldaModel.save(sc.sc(), "myLDAModel");
+ DistributedLDAModel sameModel = DistributedLDAModel.load(sc.sc(), "myLDAModel");
}
}
{% endhighlight %}
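A side note on the documentation snippet above: `LDA.run` is typed to return an `LDAModel`, and the EM-based training used here concretely produces a `DistributedLDAModel` (the test suite in this patch performs the same cast), which is why the example saves and loads through `DistributedLDAModel`. A minimal round-trip sketch, assuming a prepared `corpus: RDD[(Long, Vector)]` and a writable path `"myLDAModel"`:

```scala
import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}

// run() is typed as LDAModel; EM training concretely yields a DistributedLDAModel,
// so the cast mirrors what LDASuite does later in this patch.
val ldaModel = new LDA().setK(3).run(corpus)
val distributedModel = ldaModel.asInstanceOf[DistributedLDAModel]

distributedModel.save(sc, "myLDAModel")                     // writes metadata + parquet data
val sameModel = DistributedLDAModel.load(sc, "myLDAModel")  // restores graph and topic totals
```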
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 974b26924d..920b57756b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -17,15 +17,25 @@
package org.apache.spark.mllib.clustering
-import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, normalize, sum => brzSum, DenseVector => BDV}
+import org.apache.hadoop.fs.Path
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD
-import org.apache.spark.graphx.{VertexId, EdgeContext, Graph}
-import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
+import org.apache.spark.graphx.{VertexId, Edge, EdgeContext, Graph}
+import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix, DenseVector}
+import org.apache.spark.mllib.util.{Saveable, Loader}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
import org.apache.spark.util.BoundedPriorityQueue
+
/**
* :: Experimental ::
*
@@ -35,7 +45,7 @@ import org.apache.spark.util.BoundedPriorityQueue
* including local and distributed data structures.
*/
@Experimental
-abstract class LDAModel private[clustering] {
+abstract class LDAModel private[clustering] extends Saveable {
/** Number of topics */
def k: Int
@@ -176,6 +186,11 @@ class LocalLDAModel private[clustering] (
}.toArray
}
+ override protected def formatVersion = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ LocalLDAModel.SaveLoadV1_0.save(sc, path, topicsMatrix)
+ }
// TODO
// override def logLikelihood(documents: RDD[(Long, Vector)]): Double = ???
@@ -184,6 +199,80 @@ class LocalLDAModel private[clustering] (
}
+@Experimental
+object LocalLDAModel extends Loader[LocalLDAModel] {
+
+ private object SaveLoadV1_0 {
+
+ val thisFormatVersion = "1.0"
+
+ val thisClassName = "org.apache.spark.mllib.clustering.LocalLDAModel"
+
+ // Store each topic's term distribution as a Row in the data, together with
+ // its column index in topicsMatrix.
+ case class Data(topic: Vector, index: Int)
+
+ def save(sc: SparkContext, path: String, topicsMatrix: Matrix): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+
+ val k = topicsMatrix.numCols
+ val metadata = compact(render(
+   ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+   ("k" -> k) ~ ("vocabSize" -> topicsMatrix.numRows)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val topicsDenseMatrix = topicsMatrix.toBreeze.toDenseMatrix
+ val topics = Range(0, k).map { topicInd =>
+ Data(Vectors.dense(topicsDenseMatrix(::, topicInd).toArray), topicInd)
+ }.toSeq
+ sc.parallelize(topics, 1).toDF().write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): LocalLDAModel = {
+ val dataPath = Loader.dataPath(path)
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val dataFrame = sqlContext.read.parquet(dataPath)
+
+ Loader.checkSchema[Data](dataFrame.schema)
+ val topics = dataFrame.collect()
+ val vocabSize = topics(0).getAs[Vector](0).size
+ val k = topics.size
+
+ val brzTopics = BDM.zeros[Double](vocabSize, k)
+ topics.foreach { case Row(vec: Vector, ind: Int) =>
+ brzTopics(::, ind) := vec.toBreeze
+ }
+ new LocalLDAModel(Matrices.fromBreeze(brzTopics))
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): LocalLDAModel = {
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedK = (metadata \ "k").extract[Int]
+ val expectedVocabSize = (metadata \ "vocabSize").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+
+ val model = (loadedClassName, loadedVersion) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path)
+ case _ => throw new Exception(
+ s"LocalLDAModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+
+ val topicsMatrix = model.topicsMatrix
+ require(expectedK == topicsMatrix.numCols,
+ s"LocalLDAModel requires $expectedK topics, got ${topicsMatrix.numCols} topics")
+ require(expectedVocabSize == topicsMatrix.numRows,
+ s"LocalLDAModel requires $expectedVocabSize terms for each topic, " +
+ s"but got ${topicsMatrix.numRows}")
+ model
+ }
+}
+
/**
* :: Experimental ::
*
@@ -354,4 +443,135 @@ class DistributedLDAModel private (
// TODO:
// override def topicDistributions(documents: RDD[(Long, Vector)]): RDD[(Long, Vector)] = ???
+ override protected def formatVersion = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ DistributedLDAModel.SaveLoadV1_0.save(
+ sc, path, graph, globalTopicTotals, k, vocabSize, docConcentration, topicConcentration,
+ iterationTimes)
+ }
+}
+
+
+@Experimental
+object DistributedLDAModel extends Loader[DistributedLDAModel] {
+
+ private object SaveLoadV1_0 {
+
+ val thisFormatVersion = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.clustering.DistributedLDAModel"
+
+ // Store globalTopicTotals as a Vector.
+ case class Data(globalTopicTotals: Vector)
+
+ // Store each term or document vertex as its id together with its topic weights.
+ case class VertexData(id: Long, topicWeights: Vector)
+
+ // Store each edge with the source id, destination id and tokenCounts.
+ case class EdgeData(srcId: Long, dstId: Long, tokenCounts: Double)
+
+ def save(
+ sc: SparkContext,
+ path: String,
+ graph: Graph[LDA.TopicCounts, LDA.TokenCount],
+ globalTopicTotals: LDA.TopicCounts,
+ k: Int,
+ vocabSize: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ iterationTimes: Array[Double]): Unit = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+
+ val metadata = compact(render(
+   ("class" -> classNameV1_0) ~ ("version" -> thisFormatVersion) ~
+   ("k" -> k) ~ ("vocabSize" -> vocabSize) ~ ("docConcentration" -> docConcentration) ~
+   ("topicConcentration" -> topicConcentration) ~
+   ("iterationTimes" -> iterationTimes.toSeq)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val newPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
+ sc.parallelize(Seq(Data(Vectors.fromBreeze(globalTopicTotals)))).toDF()
+ .write.parquet(newPath)
+
+ val verticesPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
+ graph.vertices.map { case (ind, vertex) =>
+ VertexData(ind, Vectors.fromBreeze(vertex))
+ }.toDF().write.parquet(verticesPath)
+
+ val edgesPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
+ graph.edges.map { case Edge(srcId, dstId, prop) =>
+ EdgeData(srcId, dstId, prop)
+ }.toDF().write.parquet(edgesPath)
+ }
+
+ def load(
+ sc: SparkContext,
+ path: String,
+ vocabSize: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ iterationTimes: Array[Double]): DistributedLDAModel = {
+ val dataPath = new Path(Loader.dataPath(path), "globalTopicTotals").toUri.toString
+ val vertexDataPath = new Path(Loader.dataPath(path), "topicCounts").toUri.toString
+ val edgeDataPath = new Path(Loader.dataPath(path), "tokenCounts").toUri.toString
+ val sqlContext = SQLContext.getOrCreate(sc)
+ val dataFrame = sqlContext.read.parquet(dataPath)
+ val vertexDataFrame = sqlContext.read.parquet(vertexDataPath)
+ val edgeDataFrame = sqlContext.read.parquet(edgeDataPath)
+
+ Loader.checkSchema[Data](dataFrame.schema)
+ Loader.checkSchema[VertexData](vertexDataFrame.schema)
+ Loader.checkSchema[EdgeData](edgeDataFrame.schema)
+ val globalTopicTotals: LDA.TopicCounts =
+ dataFrame.first().getAs[Vector](0).toBreeze.toDenseVector
+ val vertices: RDD[(VertexId, LDA.TopicCounts)] = vertexDataFrame.map {
+ case Row(ind: Long, vec: Vector) => (ind, vec.toBreeze.toDenseVector)
+ }
+
+ val edges: RDD[Edge[LDA.TokenCount]] = edgeDataFrame.map {
+ case Row(srcId: Long, dstId: Long, prop: Double) => Edge(srcId, dstId, prop)
+ }
+ val graph: Graph[LDA.TopicCounts, LDA.TokenCount] = Graph(vertices, edges)
+
+ new DistributedLDAModel(graph, globalTopicTotals, globalTopicTotals.length, vocabSize,
+ docConcentration, topicConcentration, iterationTimes)
+ }
+
+ }
+
+ override def load(sc: SparkContext, path: String): DistributedLDAModel = {
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedK = (metadata \ "k").extract[Int]
+ val vocabSize = (metadata \ "vocabSize").extract[Int]
+ val docConcentration = (metadata \ "docConcentration").extract[Double]
+ val topicConcentration = (metadata \ "topicConcentration").extract[Double]
+ val iterationTimes = (metadata \ "iterationTimes").extract[Seq[Double]]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+
+ val model = (loadedClassName, loadedVersion) match {
+ case (className, "1.0") if className == classNameV1_0 => {
+ DistributedLDAModel.SaveLoadV1_0.load(
+ sc, path, vocabSize, docConcentration, topicConcentration, iterationTimes.toArray)
+ }
+ case _ => throw new Exception(
+ s"DistributedLDAModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported: ($classNameV1_0, 1.0)")
+ }
+
+ require(model.vocabSize == vocabSize,
+ s"DistributedLDAModel requires $vocabSize vocabSize, got ${model.vocabSize} vocabSize")
+ require(model.docConcentration == docConcentration,
+ s"DistributedLDAModel requires $docConcentration docConcentration, " +
+ s"got ${model.docConcentration} docConcentration")
+ require(model.topicConcentration == topicConcentration,
+   s"DistributedLDAModel requires $topicConcentration topicConcentration, " +
+   s"got ${model.topicConcentration} topicConcentration")
+ require(expectedK == model.k,
+ s"DistributedLDAModel requires $expectedK topics, got ${model.k} topics")
+ model
+ }
+
}
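For orientation, a sketch of how the metadata written by the V1.0 savers above can be inspected by hand. It assumes a model saved at `"myLDAModel"` and that `Loader.metadataPath(path)` resolves to `path + "/metadata"` (the usual MLlib convention); the parquet data lands under `path + "/data"`, in the `globalTopicTotals`, `topicCounts`, and `tokenCounts` subdirectories for the distributed model:

```scala
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods._

// The metadata is a one-line JSON record written with saveAsTextFile, e.g.
// {"class":"...DistributedLDAModel","version":"1.0","k":3,"vocabSize":10,...}
implicit val formats = DefaultFormats
val metadata = parse(sc.textFile("myLDAModel/metadata").first())
val k = (metadata \ "k").extract[Int]
val vocabSize = (metadata \ "vocabSize").extract[Int]
println(s"saved LDA model: k = $k, vocabSize = $vocabSize")
```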
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index 03a8a2538b..721a065658 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, DenseMatrix, Matrix, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -217,6 +218,46 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("model save/load") {
+ // Test for LocalLDAModel.
+ val localModel = new LocalLDAModel(tinyTopics)
+ val tempDir1 = Utils.createTempDir()
+ val path1 = tempDir1.toURI.toString
+
+ // Test for DistributedLDAModel.
+ val k = 3
+ val docConcentration = 1.2
+ val topicConcentration = 1.5
+ val lda = new LDA()
+ lda.setK(k)
+ .setDocConcentration(docConcentration)
+ .setTopicConcentration(topicConcentration)
+ .setMaxIterations(5)
+ .setSeed(12345)
+ val corpus = sc.parallelize(tinyCorpus, 2)
+ val distributedModel: DistributedLDAModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+ val tempDir2 = Utils.createTempDir()
+ val path2 = tempDir2.toURI.toString
+
+ try {
+ localModel.save(sc, path1)
+ distributedModel.save(sc, path2)
+ val sameLocalModel = LocalLDAModel.load(sc, path1)
+ assert(sameLocalModel.topicsMatrix === localModel.topicsMatrix)
+ assert(sameLocalModel.k === localModel.k)
+ assert(sameLocalModel.vocabSize === localModel.vocabSize)
+
+ val sameDistributedModel = DistributedLDAModel.load(sc, path2)
+ assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
+ assert(distributedModel.k === sameDistributedModel.k)
+ assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
+ assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
+ } finally {
+ Utils.deleteRecursively(tempDir1)
+ Utils.deleteRecursively(tempDir2)
+ }
+ }
+
}
private[clustering] object LDASuite {
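Finally, a usage sketch combining the two loaders, under the assumption that `DistributedLDAModel.toLocal` (which wraps the topics matrix in a `LocalLDAModel`) is available in this version of MLlib; the paths here are hypothetical:

```scala
// Load the trained distributed model, collapse it to its topics matrix,
// and round-trip the result through LocalLDAModel's own save/load path.
val distributedModel = DistributedLDAModel.load(sc, "myLDAModel")
val localModel = distributedModel.toLocal
localModel.save(sc, "myLocalLDAModel")
val reloaded = LocalLDAModel.load(sc, "myLocalLDAModel")
assert(reloaded.topicsMatrix.toArray.sameElements(localModel.topicsMatrix.toArray))
```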