diff options
author | Jeff Zhang <zjffdu@gmail.com> | 2016-04-29 10:42:52 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-29 10:42:52 -0700 |
commit | 775772de36d5b7e80595aad850aa1dcea8791688 (patch) | |
tree | 2b2f67da23565dd7c1ac2d6758bfe502c2c76cd5 /python/pyspark/ml/tests.py | |
parent | f08dcdb8d33d2a40573547ae8543e409b6ab9e59 (diff) | |
download | spark-775772de36d5b7e80595aad850aa1dcea8791688.tar.gz spark-775772de36d5b7e80595aad850aa1dcea8791688.tar.bz2 spark-775772de36d5b7e80595aad850aa1dcea8791688.zip |
[SPARK-11940][PYSPARK][ML] Python API for ml.clustering.LDA PR2
## What changes were proposed in this pull request?
pyspark.ml API for LDA
* LDA, LDAModel, LocalLDAModel, DistributedLDAModel
* includes persistence
This replaces [https://github.com/apache/spark/pull/10242]
## How was this patch tested?
* doc test for LDA, including Param setters
* unit test for persistence
Author: Joseph K. Bradley <joseph@databricks.com>
Author: Jeff Zhang <zjffdu@apache.org>
Closes #12723 from jkbradley/zjffdu-SPARK-11940.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 57 |
1 files changed, 56 insertions, 1 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 36cecd4682..e7d4c0af45 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -46,7 +46,7 @@ from pyspark import keyword_only from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer from pyspark.ml.classification import ( LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel) -from pyspark.ml.clustering import KMeans +from pyspark.ml.clustering import * from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator from pyspark.ml.feature import * from pyspark.ml.param import Param, Params, TypeConverters @@ -809,6 +809,61 @@ class PersistenceTest(PySparkTestCase): pass +class LDATest(PySparkTestCase): + + def _compare(self, m1, m2): + """ + Temp method for comparing instances. + TODO: Replace with generic implementation once SPARK-14706 is merged. + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + if m1.isDefined(p): + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + if isinstance(m1, LDAModel): + self.assertEqual(m1.vocabSize(), m2.vocabSize()) + self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix()) + + def test_persistence(self): + # Test save/load for LDA, LocalLDAModel, DistributedLDAModel. + sqlContext = SQLContext(self.sc) + df = sqlContext.createDataFrame([ + [1, Vectors.dense([0.0, 1.0])], + [2, Vectors.sparse(2, {0: 1.0})], + ], ["id", "features"]) + # Fit model + lda = LDA(k=2, seed=1, optimizer="em") + distributedModel = lda.fit(df) + self.assertTrue(distributedModel.isDistributed()) + localModel = distributedModel.toLocal() + self.assertFalse(localModel.isDistributed()) + # Define paths + path = tempfile.mkdtemp() + lda_path = path + "/lda" + dist_model_path = path + "/distLDAModel" + local_model_path = path + "/localLDAModel" + # Test LDA + lda.save(lda_path) + lda2 = LDA.load(lda_path) + self._compare(lda, lda2) + # Test DistributedLDAModel + distributedModel.save(dist_model_path) + distributedModel2 = DistributedLDAModel.load(dist_model_path) + self._compare(distributedModel, distributedModel2) + # Test LocalLDAModel + localModel.save(local_model_path) + localModel2 = LocalLDAModel.load(local_model_path) + self._compare(localModel, localModel2) + # Clean up + try: + rmtree(path) + except OSError: + pass + + class TrainingSummaryTest(PySparkTestCase): def test_linear_regression_summary(self): |