aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala33
-rw-r--r--python/pyspark/mllib/clustering.py66
2 files changed, 98 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index e628059c4a..c58a64001d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -503,6 +503,39 @@ private[python] class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib LDA.train(); each row of `data` is an [id, Vector] pair.
+ */
+ def trainLDAModel(
+ data: JavaRDD[java.util.List[Any]],
+ k: Int,
+ maxIterations: Int,
+ docConcentration: Double,
+ topicConcentration: Double,
+ seed: java.lang.Long,
+ checkpointInterval: Int,
+ optimizer: String): LDAModel = {
+ val algo = new LDA()
+ .setK(k)
+ .setMaxIterations(maxIterations)
+ .setDocConcentration(docConcentration)
+ .setTopicConcentration(topicConcentration)
+ .setCheckpointInterval(checkpointInterval)
+ .setOptimizer(optimizer)
+
+ if (seed != null) algo.setSeed(seed) // null means no seed was supplied; setSeed is skipped
+
+ val documents = data.rdd.map(_.asScala.toArray).map { r =>
+ r(0) match {
+ case i: java.lang.Integer => (i.toLong, r(1).asInstanceOf[Vector])
+ case i: java.lang.Long => (i.toLong, r(1).asInstanceOf[Vector])
+ case _ => throw new IllegalArgumentException("input values contains invalid type value.")
+ }
+ }
+ algo.run(documents)
+ }
+
+
+ /**
* Java stub for Python mllib FPGrowth.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
* needs to be taken in the Python code to ensure it gets freed on exit; see
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index ed4d78a2c6..8a92f6911c 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -31,13 +31,15 @@ from pyspark import SparkContext
from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, callJavaFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
+from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.util import Saveable, Loader, inherit_doc, JavaLoader, JavaSaveable
from pyspark.streaming import DStream
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
'PowerIterationClusteringModel', 'PowerIterationClustering',
- 'StreamingKMeans', 'StreamingKMeansModel']
+ 'StreamingKMeans', 'StreamingKMeansModel',
+ 'LDA', 'LDAModel']
@inherit_doc
@@ -563,6 +565,68 @@ class StreamingKMeans(object):
return dstream.mapValues(lambda x: self._model.predict(x))
+class LDAModel(JavaModelWrapper):
+
+ """A clustering model derived from the LDA method.
+
+ Latent Dirichlet Allocation (LDA), a topic model designed for text documents.
+ Terminology:
+ - "word" = "term": an element of the vocabulary
+ - "token": instance of a term appearing in a document
+ - "topic": multinomial distribution over words representing some concept
+ References:
+ - Original LDA paper (journal version):
+ Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
+
+ >>> from pyspark.mllib.linalg import Vectors
+ >>> from numpy.testing import assert_almost_equal
+ >>> data = [
+ ... [1, Vectors.dense([0.0, 1.0])],
+ ... [2, SparseVector(2, {0: 1.0})],
+ ... ]
+ >>> rdd = sc.parallelize(data)
+ >>> model = LDA.train(rdd, k=2)
+ >>> model.vocabSize()
+ 2
+ >>> topics = model.topicsMatrix()
+ >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
+ >>> assert_almost_equal(topics, topics_expect, 1)
+ """
+
+ def topicsMatrix(self):
+ """Inferred topics, where each topic is represented by a distribution over terms."""
+ return self.call("topicsMatrix").toArray()
+
+ def vocabSize(self):
+ """Vocabulary size (number of terms or words in the vocabulary)."""
+ return self.call("vocabSize")
+
+
+class LDA(object):
+
+ @classmethod
+ def train(cls, rdd, k=10, maxIterations=20, docConcentration=-1.0,
+ topicConcentration=-1.0, seed=None, checkpointInterval=10, optimizer="em"):
+ """Train an LDA model.
+
+ :param rdd: RDD of documents, each given as [document id (integer), vector].
+ :param k: Number of topics to infer. Defaults to 10.
+ :param maxIterations: Maximum number of iterations. Defaults to 20.
+ :param docConcentration: Concentration parameter (commonly named "alpha")
+ for the prior placed on documents' distributions over topics ("theta").
+ :param topicConcentration: Concentration parameter (commonly named "beta" or "eta")
+ for the prior placed on topics' distributions over terms.
+ :param seed: Random seed. If None (the default), no seed is set on the JVM side.
+ :param checkpointInterval: Period (in iterations) between checkpoints.
+ :param optimizer: LDAOptimizer used to perform the actual calculation.
+ Currently "em" and "online" are supported. Defaults to "em".
+ """
+ model = callMLlibFunc("trainLDAModel", rdd, k, maxIterations,
+ docConcentration, topicConcentration, seed,
+ checkpointInterval, optimizer)
+ return LDAModel(model)
+
+
def _test():
import doctest
import pyspark.mllib.clustering