author      MechCoder <manojkumarsivaraj334@gmail.com>    2015-07-22 17:22:12 -0700
committer   Xiangrui Meng <meng@databricks.com>    2015-07-22 17:22:12 -0700
commit      5307c9d3f7a35c0276b72e743e3a62a44d2bd0f5 (patch)
tree        6e5b776bd0aea58e9548135347d25efbd4c18715 /python
parent      430cd7815dc7875edd126af4b90752ba8a380cf2 (diff)
[SPARK-9223] [PYSPARK] [MLLIB] Support model save/load in LDA
Since save/load has been merged for LDA on the Scala side, it is straightforward to add the Python wrappers as well.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7587 from MechCoder/python_lda_save_load and squashes the following commits:

c8e4ea7 [MechCoder] [SPARK-9223] [PySpark] Support model save/load in LDA
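As a quick orientation for reviewers, here is a minimal sketch of how the new API is used once this patch is applied. It assumes a live SparkContext `sc`; the tiny corpus and the variable names are illustrative, mirroring the doctest data below:

    from pyspark.mllib.clustering import LDA, LDAModel
    from pyspark.mllib.linalg import Vectors
    import tempfile

    # Corpus of [document id, term-count vector] pairs, as in the doctest.
    corpus = sc.parallelize([
        [1, Vectors.dense([0.0, 1.0])],
        [2, Vectors.dense([1.0, 0.0])],
    ])
    model = LDA.train(corpus, k=2)

    # Round-trip the fitted model through disk via the new wrappers.
    path = tempfile.mkdtemp()
    model.save(sc, path)
    same_model = LDAModel.load(sc, path)
    assert same_model.vocabSize() == model.vocabSize()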
Diffstat (limited to 'python')
-rw-r--r--    python/pyspark/mllib/clustering.py    43
1 file changed, 42 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 8a92f6911c..58ad99d46e 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -20,6 +20,7 @@ import array as pyarray
 
 if sys.version > '3':
     xrange = range
+    basestring = str
 
 from math import exp, log
 
@@ -579,7 +580,7 @@ class LDAModel(JavaModelWrapper):
     Blei, Ng, and Jordan. "Latent Dirichlet Allocation." JMLR, 2003.
 
     >>> from pyspark.mllib.linalg import Vectors
-    >>> from numpy.testing import assert_almost_equal
+    >>> from numpy.testing import assert_almost_equal, assert_equal
     >>> data = [
     ...     [1, Vectors.dense([0.0, 1.0])],
     ...     [2, SparseVector(2, {0: 1.0})],
@@ -591,6 +592,19 @@ class LDAModel(JavaModelWrapper):
     >>> topics = model.topicsMatrix()
     >>> topics_expect = array([[0.5, 0.5], [0.5, 0.5]])
     >>> assert_almost_equal(topics, topics_expect, 1)
+
+    >>> import tempfile
+    >>> from shutil import rmtree
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = LDAModel.load(sc, path)
+    >>> assert_equal(sameModel.topicsMatrix(), model.topicsMatrix())
+    >>> sameModel.vocabSize() == model.vocabSize()
+    True
+    >>> try:
+    ...     rmtree(path)
+    ... except OSError:
+    ...     pass
     """
 
     def topicsMatrix(self):
@@ -601,6 +615,33 @@ class LDAModel(JavaModelWrapper):
         """Vocabulary size (number of terms in the vocabulary)"""
         return self.call("vocabSize")
 
+    def save(self, sc, path):
+        """Save this LDAModel to disk.
+
+        :param sc: SparkContext.
+        :param path: str, path to where the model should be stored.
+        """
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a string, got type %s" % type(path))
+        self._java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        """Load the LDAModel from disk.
+
+        :param sc: SparkContext.
+        :param path: str, path to where the model is stored.
+        """
+        if not isinstance(sc, SparkContext):
+            raise TypeError("sc should be a SparkContext, got type %s" % type(sc))
+        if not isinstance(path, basestring):
+            raise TypeError("path should be a string, got type %s" % type(path))
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.DistributedLDAModel.load(
+            sc._jsc.sc(), path)
+        return cls(java_model)
+
 
 class LDA(object):
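One note on the load path above: `LDAModel.load` delegates to `DistributedLDAModel.load` on the JVM side, so it presumably round-trips only models backed by a distributed LDA model, which is what the default EM optimizer produces; loading a model stored as a `LocalLDAModel` would need a separate code path.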