aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2015-11-24 09:56:17 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-24 09:56:17 -0800
commit52bc25c8e26d4be250d8ff7864067528f4f98592 (patch)
treee694ac3ea0585288f0b75055819043a1d76e1a8c /mllib/src/main/scala/org
parent9e24ba667e43290fbaa3cacb93cf5d9be790f1fd (diff)
downloadspark-52bc25c8e26d4be250d8ff7864067528f4f98592.tar.gz
spark-52bc25c8e26d4be250d8ff7864067528f4f98592.tar.bz2
spark-52bc25c8e26d4be250d8ff7864067528f4f98592.zip
[SPARK-11847][ML] Model export/import for spark.ml: LDA
Add read/write support to LDA, similar to ALS. save/load for ml.LocalLDAModel is done. For DistributedLDAModel, I'm not sure if we can invoke save on the mllib.DistributedLDAModel directly. I'll send an update after some testing. Author: Yuhao Yang <hhbyyh@gmail.com> Closes #9894 from hhbyyh/ldaMLsave.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala110
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala4
2 files changed, 108 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 92e05815d6..830510b169 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -17,12 +17,13 @@
package org.apache.spark.ml.clustering
+import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter}
import org.apache.spark.ml.param._
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
@@ -322,7 +323,7 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
- extends Model[LDAModel] with LDAParams with Logging {
+ extends Model[LDAModel] with LDAParams with Logging with MLWritable {
// NOTE to developers:
// This abstraction should contain all important functionality for basic LDA usage.
@@ -486,6 +487,64 @@ class LocalLDAModel private[ml] (
@Since("1.6.0")
override def isDistributed: Boolean = false
+
+ @Since("1.6.0")
+ override def write: MLWriter = new LocalLDAModel.LocalLDAModelWriter(this)
+}
+
+
+@Since("1.6.0")
+object LocalLDAModel extends MLReadable[LocalLDAModel] {
+
+ private[LocalLDAModel]
+ class LocalLDAModelWriter(instance: LocalLDAModel) extends MLWriter {
+
+ private case class Data(
+ vocabSize: Int,
+ topicsMatrix: Matrix,
+ docConcentration: Vector,
+ topicConcentration: Double,
+ gammaShape: Double)
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val oldModel = instance.oldLocalModel
+ val data = Data(instance.vocabSize, oldModel.topicsMatrix, oldModel.docConcentration,
+ oldModel.topicConcentration, oldModel.gammaShape)
+ val dataPath = new Path(path, "data").toString
+ sqlContext.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
+ }
+
+ private class LocalLDAModelReader extends MLReader[LocalLDAModel] {
+
+ private val className = classOf[LocalLDAModel].getName
+
+ override def load(path: String): LocalLDAModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val dataPath = new Path(path, "data").toString
+ val data = sqlContext.read.parquet(dataPath)
+ .select("vocabSize", "topicsMatrix", "docConcentration", "topicConcentration",
+ "gammaShape")
+ .head()
+ val vocabSize = data.getAs[Int](0)
+ val topicsMatrix = data.getAs[Matrix](1)
+ val docConcentration = data.getAs[Vector](2)
+ val topicConcentration = data.getAs[Double](3)
+ val gammaShape = data.getAs[Double](4)
+ val oldModel = new OldLocalLDAModel(topicsMatrix, docConcentration, topicConcentration,
+ gammaShape)
+ val model = new LocalLDAModel(metadata.uid, vocabSize, oldModel, sqlContext)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[LocalLDAModel] = new LocalLDAModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): LocalLDAModel = super.load(path)
}
@@ -562,6 +621,45 @@ class DistributedLDAModel private[ml] (
*/
@Since("1.6.0")
lazy val logPrior: Double = oldDistributedModel.logPrior
+
+ @Since("1.6.0")
+ override def write: MLWriter = new DistributedLDAModel.DistributedWriter(this)
+}
+
+
+@Since("1.6.0")
+object DistributedLDAModel extends MLReadable[DistributedLDAModel] {
+
+ private[DistributedLDAModel]
+ class DistributedWriter(instance: DistributedLDAModel) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ DefaultParamsWriter.saveMetadata(instance, path, sc)
+ val modelPath = new Path(path, "oldModel").toString
+ instance.oldDistributedModel.save(sc, modelPath)
+ }
+ }
+
+ private class DistributedLDAModelReader extends MLReader[DistributedLDAModel] {
+
+ private val className = classOf[DistributedLDAModel].getName
+
+ override def load(path: String): DistributedLDAModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val modelPath = new Path(path, "oldModel").toString
+ val oldModel = OldDistributedLDAModel.load(sc, modelPath)
+ val model = new DistributedLDAModel(
+ metadata.uid, oldModel.vocabSize, oldModel, sqlContext, None)
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+
+ @Since("1.6.0")
+ override def read: MLReader[DistributedLDAModel] = new DistributedLDAModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): DistributedLDAModel = super.load(path)
}
@@ -593,7 +691,8 @@ class DistributedLDAModel private[ml] (
@Since("1.6.0")
@Experimental
class LDA @Since("1.6.0") (
- @Since("1.6.0") override val uid: String) extends Estimator[LDAModel] with LDAParams {
+ @Since("1.6.0") override val uid: String)
+ extends Estimator[LDAModel] with LDAParams with DefaultParamsWritable {
@Since("1.6.0")
def this() = this(Identifiable.randomUID("lda"))
@@ -695,7 +794,7 @@ class LDA @Since("1.6.0") (
}
-private[clustering] object LDA {
+private[clustering] object LDA extends DefaultParamsReadable[LDA] {
/** Get dataset for spark.mllib LDA */
def getOldDataset(dataset: DataFrame, featuresCol: String): RDD[(Long, Vector)] = {
@@ -706,4 +805,7 @@ private[clustering] object LDA {
(docId, features)
}
}
+
+ @Since("1.6.0")
+ override def load(path: String): LDA = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index cd520f09bd..7384d065a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -187,11 +187,11 @@ abstract class LDAModel private[clustering] extends Saveable {
* @param topics Inferred topics (vocabSize x k matrix).
*/
@Since("1.3.0")
-class LocalLDAModel private[clustering] (
+class LocalLDAModel private[spark] (
@Since("1.3.0") val topics: Matrix,
@Since("1.5.0") override val docConcentration: Vector,
@Since("1.5.0") override val topicConcentration: Double,
- override protected[clustering] val gammaShape: Double = 100)
+ override protected[spark] val gammaShape: Double = 100)
extends LDAModel with Serializable {
@Since("1.3.0")