aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
authorJeff Zhang <zjffdu@gmail.com>2016-04-29 10:42:52 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-29 10:42:52 -0700
commit775772de36d5b7e80595aad850aa1dcea8791688 (patch)
tree2b2f67da23565dd7c1ac2d6758bfe502c2c76cd5 /python/pyspark/ml/tests.py
parentf08dcdb8d33d2a40573547ae8543e409b6ab9e59 (diff)
downloadspark-775772de36d5b7e80595aad850aa1dcea8791688.tar.gz
spark-775772de36d5b7e80595aad850aa1dcea8791688.tar.bz2
spark-775772de36d5b7e80595aad850aa1dcea8791688.zip
[SPARK-11940][PYSPARK][ML] Python API for ml.clustering.LDA PR2
## What changes were proposed in this pull request? pyspark.ml API for LDA * LDA, LDAModel, LocalLDAModel, DistributedLDAModel * includes persistence This replaces [https://github.com/apache/spark/pull/10242] ## How was this patch tested? * doc test for LDA, including Param setters * unit test for persistence Author: Joseph K. Bradley <joseph@databricks.com> Author: Jeff Zhang <zjffdu@apache.org> Closes #12723 from jkbradley/zjffdu-SPARK-11940.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py57
1 file changed, 56 insertions, 1 deletion
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 36cecd4682..e7d4c0af45 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -46,7 +46,7 @@ from pyspark import keyword_only
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import (
LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
-from pyspark.ml.clustering import KMeans
+from pyspark.ml.clustering import *
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
@@ -809,6 +809,61 @@ class PersistenceTest(PySparkTestCase):
pass
+class LDATest(PySparkTestCase):
+
+ def _compare(self, m1, m2):
+ """
+ Temp method for comparing instances.
+ TODO: Replace with generic implementation once SPARK-14706 is merged.
+ """
+ self.assertEqual(m1.uid, m2.uid)
+ self.assertEqual(type(m1), type(m2))
+ self.assertEqual(len(m1.params), len(m2.params))
+ for p in m1.params:
+ if m1.isDefined(p):
+ self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+ self.assertEqual(p.parent, m2.getParam(p.name).parent)
+ if isinstance(m1, LDAModel):
+ self.assertEqual(m1.vocabSize(), m2.vocabSize())
+ self.assertEqual(m1.topicsMatrix(), m2.topicsMatrix())
+
+ def test_persistence(self):
+ # Test save/load for LDA, LocalLDAModel, DistributedLDAModel.
+ sqlContext = SQLContext(self.sc)
+ df = sqlContext.createDataFrame([
+ [1, Vectors.dense([0.0, 1.0])],
+ [2, Vectors.sparse(2, {0: 1.0})],
+ ], ["id", "features"])
+ # Fit model
+ lda = LDA(k=2, seed=1, optimizer="em")
+ distributedModel = lda.fit(df)
+ self.assertTrue(distributedModel.isDistributed())
+ localModel = distributedModel.toLocal()
+ self.assertFalse(localModel.isDistributed())
+ # Define paths
+ path = tempfile.mkdtemp()
+ lda_path = path + "/lda"
+ dist_model_path = path + "/distLDAModel"
+ local_model_path = path + "/localLDAModel"
+ # Test LDA
+ lda.save(lda_path)
+ lda2 = LDA.load(lda_path)
+ self._compare(lda, lda2)
+ # Test DistributedLDAModel
+ distributedModel.save(dist_model_path)
+ distributedModel2 = DistributedLDAModel.load(dist_model_path)
+ self._compare(distributedModel, distributedModel2)
+ # Test LocalLDAModel
+ localModel.save(local_model_path)
+ localModel2 = LocalLDAModel.load(local_model_path)
+ self._compare(localModel, localModel2)
+ # Clean up
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+
class TrainingSummaryTest(PySparkTestCase):
def test_linear_regression_summary(self):