author    Joseph K. Bradley <joseph@databricks.com>    2015-11-12 17:03:19 -0800
committer Joseph K. Bradley <joseph@databricks.com>    2015-11-12 17:03:19 -0800
commit    dcb896fd8cec83483f700ee985c352be61cdf233 (patch)
tree      42af6acc1a3c9f916aa154b564fd8d548df060c9 /mllib/src
parent    bc092966f8264c6685b3300461cb79dd6a509ecf (diff)
[SPARK-11712][ML] Make spark.ml LDAModel be abstract
Per discussion in the initial Pipelines LDA PR [https://github.com/apache/spark/pull/9513], we should make LDAModel abstract and create a LocalLDAModel. This code simplification should be done before the 1.6 release to ensure API compatibility in future releases.

CC feynmanliang mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9678 from jkbradley/lda-pipelines-2.
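For context, a minimal usage sketch of the resulting class hierarchy (assuming Spark 1.6; `corpus` and `inspect` are hypothetical, and the method names are taken from the patched API shown below):

import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel, LocalLDAModel}
import org.apache.spark.sql.DataFrame

def inspect(corpus: DataFrame): Unit = {
  val lda = new LDA().setK(10).setMaxIter(20).setOptimizer("em")
  // fit() now returns the abstract LDAModel; the concrete class depends on the optimizer.
  val model: LDAModel = lda.fit(corpus)
  model match {
    case m: LocalLDAModel =>
      // Produced by the "online" optimizer: topics are already stored locally.
      println(s"topicsMatrix is ${m.topicsMatrix.numRows} x ${m.topicsMatrix.numCols}")
    case m: DistributedLDAModel =>
      // Produced by "em": toLocal collects a large topicsMatrix to the driver.
      val local: LocalLDAModel = m.toLocal
      println(s"isDistributed = ${m.isDistributed}, vocabSize = ${local.vocabSize}")
  }
}

Because the patch makes LDAModel a sealed abstract class, the match above is exhaustive over the two concrete model types.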
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala      | 180
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala |   4
2 files changed, 96 insertions(+), 88 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index f66233ed3d..92e05815d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -314,31 +314,31 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* Model fitted by [[LDA]].
*
* @param vocabSize Vocabulary size (number of terms or words in the vocabulary)
- * @param oldLocalModel Underlying spark.mllib model.
- * If this model was produced by Online LDA, then this is the
- * only model representation.
- * If this model was produced by EM, then this local
- * representation may be built lazily.
* @param sqlContext Used to construct local DataFrames for returning query results
*/
@Since("1.6.0")
@Experimental
-class LDAModel private[ml] (
+sealed abstract class LDAModel private[ml] (
@Since("1.6.0") override val uid: String,
@Since("1.6.0") val vocabSize: Int,
- @Since("1.6.0") protected var oldLocalModel: Option[OldLocalLDAModel],
@Since("1.6.0") @transient protected val sqlContext: SQLContext)
extends Model[LDAModel] with LDAParams with Logging {
- /** Returns underlying spark.mllib model */
+ // NOTE to developers:
+ // This abstraction should contain all important functionality for basic LDA usage.
+ // Specializations of this class can contain expert-only functionality.
+
+ /**
+ * Underlying spark.mllib model.
+ * If this model was produced by Online LDA, then this is the only model representation.
+ * If this model was produced by EM, then this local representation may be built lazily.
+ */
@Since("1.6.0")
- protected def getModel: OldLDAModel = oldLocalModel match {
- case Some(m) => m
- case None =>
- // Should never happen.
- throw new RuntimeException("LDAModel required local model format," +
- " but the underlying model is missing.")
- }
+ protected def oldLocalModel: OldLocalLDAModel
+
+ /** Returns underlying spark.mllib model, which may be local or distributed */
+ @Since("1.6.0")
+ protected def getModel: OldLDAModel
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
@@ -352,16 +352,17 @@ class LDAModel private[ml] (
@Since("1.6.0")
def setSeed(value: Long): this.type = set(seed, value)
- @Since("1.6.0")
- override def copy(extra: ParamMap): LDAModel = {
- val copied = new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
- copyValues(copied, extra).setParent(parent)
- }
-
+ /**
+ * Transforms the input dataset.
+ *
+ * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+ * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+ * This implementation may be changed in the future.
+ */
@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
if ($(topicDistributionCol).nonEmpty) {
- val t = udf(oldLocalModel.get.getTopicDistributionMethod(sqlContext.sparkContext))
+ val t = udf(oldLocalModel.getTopicDistributionMethod(sqlContext.sparkContext))
dataset.withColumn($(topicDistributionCol), t(col($(featuresCol))))
} else {
logWarning("LDAModel.transform was called without any output columns. Set an output column" +
@@ -388,56 +389,50 @@ class LDAModel private[ml] (
* This is a matrix of size vocabSize x k, where each column is a topic.
* No guarantees are given about the ordering of the topics.
*
- * WARNING: If this model is actually a [[DistributedLDAModel]] instance from EM,
- * then this method could involve collecting a large amount of data to the driver
- * (on the order of vocabSize x k).
+ * WARNING: If this model is actually a [[DistributedLDAModel]] instance produced by
+ * the Expectation-Maximization ("em") [[optimizer]], then this method could involve
+ * collecting a large amount of data to the driver (on the order of vocabSize x k).
*/
@Since("1.6.0")
- def topicsMatrix: Matrix = getModel.topicsMatrix
+ def topicsMatrix: Matrix = oldLocalModel.topicsMatrix
/** Indicates whether this instance is of type [[DistributedLDAModel]] */
@Since("1.6.0")
- def isDistributed: Boolean = false
+ def isDistributed: Boolean
/**
* Calculates a lower bound on the log likelihood of the entire corpus.
*
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
- * WARNING: If this model was learned via a [[DistributedLDAModel]], this involves collecting
- * a large [[topicsMatrix]] to the driver. This implementation may be changed in the
- * future.
+ * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+ * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+ * This implementation may be changed in the future.
*
* @param dataset test corpus to use for calculating log likelihood
* @return variational lower bound on the log likelihood of the entire corpus
*/
@Since("1.6.0")
- def logLikelihood(dataset: DataFrame): Double = oldLocalModel match {
- case Some(m) =>
- val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
- m.logLikelihood(oldDataset)
- case None =>
- // Should never happen.
- throw new RuntimeException("LocalLDAModel.logLikelihood was called," +
- " but the underlying model is missing.")
+ def logLikelihood(dataset: DataFrame): Double = {
+ val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+ oldLocalModel.logLikelihood(oldDataset)
}
/**
* Calculate an upper bound on perplexity. (Lower is better.)
* See Equation (16) in the Online LDA paper (Hoffman et al., 2010).
*
+ * WARNING: If this model is an instance of [[DistributedLDAModel]] (produced when [[optimizer]]
+ * is set to "em"), this involves collecting a large [[topicsMatrix]] to the driver.
+ * This implementation may be changed in the future.
+ *
* @param dataset test corpus to use for calculating perplexity
* @return Variational upper bound on log perplexity per token.
*/
@Since("1.6.0")
- def logPerplexity(dataset: DataFrame): Double = oldLocalModel match {
- case Some(m) =>
- val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
- m.logPerplexity(oldDataset)
- case None =>
- // Should never happen.
- throw new RuntimeException("LocalLDAModel.logPerplexity was called," +
- " but the underlying model is missing.")
+ def logPerplexity(dataset: DataFrame): Double = {
+ val oldDataset = LDA.getOldDataset(dataset, $(featuresCol))
+ oldLocalModel.logPerplexity(oldDataset)
}
/**
@@ -468,10 +463,43 @@ class LDAModel private[ml] (
/**
* :: Experimental ::
*
- * Distributed model fitted by [[LDA]] using Expectation-Maximization (EM).
+ * Local (non-distributed) model fitted by [[LDA]].
+ *
+ * This model stores the inferred topics only; it does not store info about the training dataset.
+ */
+@Since("1.6.0")
+@Experimental
+class LocalLDAModel private[ml] (
+ uid: String,
+ vocabSize: Int,
+ @Since("1.6.0") override protected val oldLocalModel: OldLocalLDAModel,
+ sqlContext: SQLContext)
+ extends LDAModel(uid, vocabSize, sqlContext) {
+
+ @Since("1.6.0")
+ override def copy(extra: ParamMap): LocalLDAModel = {
+ val copied = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
+ copyValues(copied, extra).setParent(parent).asInstanceOf[LocalLDAModel]
+ }
+
+ override protected def getModel: OldLDAModel = oldLocalModel
+
+ @Since("1.6.0")
+ override def isDistributed: Boolean = false
+}
+
+
+/**
+ * :: Experimental ::
+ *
+ * Distributed model fitted by [[LDA]].
+ * This type of model is currently only produced by Expectation-Maximization (EM).
*
* This model stores the inferred topics, the full training dataset, and the topic distribution
* for each training document.
+ *
+ * @param oldLocalModelOption Used to implement [[oldLocalModel]] as a lazy val while keeping
+ * [[copy()]] cheap.
*/
@Since("1.6.0")
@Experimental
@@ -479,59 +507,39 @@ class DistributedLDAModel private[ml] (
uid: String,
vocabSize: Int,
private val oldDistributedModel: OldDistributedLDAModel,
- sqlContext: SQLContext)
- extends LDAModel(uid, vocabSize, None, sqlContext) {
+ sqlContext: SQLContext,
+ private var oldLocalModelOption: Option[OldLocalLDAModel])
+ extends LDAModel(uid, vocabSize, sqlContext) {
+
+ override protected def oldLocalModel: OldLocalLDAModel = {
+ if (oldLocalModelOption.isEmpty) {
+ oldLocalModelOption = Some(oldDistributedModel.toLocal)
+ }
+ oldLocalModelOption.get
+ }
+
+ override protected def getModel: OldLDAModel = oldDistributedModel
/**
* Convert this distributed model to a local representation. This discards info about the
* training dataset.
+ *
+ * WARNING: This involves collecting a large [[topicsMatrix]] to the driver.
*/
@Since("1.6.0")
- def toLocal: LDAModel = {
- if (oldLocalModel.isEmpty) {
- oldLocalModel = Some(oldDistributedModel.toLocal)
- }
- new LDAModel(uid, vocabSize, oldLocalModel, sqlContext)
- }
-
- @Since("1.6.0")
- override protected def getModel: OldLDAModel = oldDistributedModel
+ def toLocal: LocalLDAModel = new LocalLDAModel(uid, vocabSize, oldLocalModel, sqlContext)
@Since("1.6.0")
override def copy(extra: ParamMap): DistributedLDAModel = {
- val copied = new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext)
- if (oldLocalModel.nonEmpty) copied.oldLocalModel = oldLocalModel
+ val copied =
+ new DistributedLDAModel(uid, vocabSize, oldDistributedModel, sqlContext, oldLocalModelOption)
copyValues(copied, extra).setParent(parent)
copied
}
@Since("1.6.0")
- override def topicsMatrix: Matrix = {
- if (oldLocalModel.isEmpty) {
- oldLocalModel = Some(oldDistributedModel.toLocal)
- }
- super.topicsMatrix
- }
-
- @Since("1.6.0")
override def isDistributed: Boolean = true
- @Since("1.6.0")
- override def logLikelihood(dataset: DataFrame): Double = {
- if (oldLocalModel.isEmpty) {
- oldLocalModel = Some(oldDistributedModel.toLocal)
- }
- super.logLikelihood(dataset)
- }
-
- @Since("1.6.0")
- override def logPerplexity(dataset: DataFrame): Double = {
- if (oldLocalModel.isEmpty) {
- oldLocalModel = Some(oldDistributedModel.toLocal)
- }
- super.logPerplexity(dataset)
- }
-
/**
* Log likelihood of the observed tokens in the training set,
* given the current parameter estimates:
@@ -673,9 +681,9 @@ class LDA @Since("1.6.0") (
val oldModel = oldLDA.run(oldData)
val newModel = oldModel match {
case m: OldLocalLDAModel =>
- new LDAModel(uid, m.vocabSize, Some(m), dataset.sqlContext)
+ new LocalLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
case m: OldDistributedLDAModel =>
- new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext)
+ new DistributedLDAModel(uid, m.vocabSize, m, dataset.sqlContext, None)
}
copyValues(newModel).setParent(this)
}
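Aside: the oldLocalModel override above memoizes through a private var Option rather than a lazy val, so that copy() can hand any already-computed local model to the clone. A standalone sketch of that pattern, with simplified stand-in names (Local, Distributed, expensiveToLocal) rather than Spark API:

class Local(val topics: Array[Double])

class Distributed(private var localOption: Option[Local]) {
  private def expensiveToLocal(): Local = new Local(Array.fill(4)(0.0)) // placeholder for a costly conversion

  // Memoize: convert at most once, then reuse the cached local form.
  def toLocal: Local = {
    if (localOption.isEmpty) {
      localOption = Some(expensiveToLocal())
    }
    localOption.get
  }

  // Cheap copy: share whatever has been computed so far.
  def copy(): Distributed = new Distributed(localOption)
}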
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index edb927495e..b634d31cc3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -156,7 +156,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
MLTestingUtils.checkCopy(model)
- assert(!model.isInstanceOf[DistributedLDAModel])
+ assert(model.isInstanceOf[LocalLDAModel])
assert(model.vocabSize === vocabSize)
assert(model.estimatedDocConcentration.size === k)
assert(model.topicsMatrix.numRows === vocabSize)
@@ -210,7 +210,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.isDistributed)
val localModel = model.toLocal
- assert(!localModel.isInstanceOf[DistributedLDAModel])
+ assert(localModel.isInstanceOf[LocalLDAModel])
// training logLikelihood, logPrior
val ll = model.trainingLogLikelihood