diff options
author | Feynman Liang <fliang@databricks.com> | 2015-07-29 19:02:15 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-29 19:02:15 -0700 |
commit | a200e64561c8803731578267df16906f6773cbea (patch) | |
tree | 838571920db0948b88e385c8ed1faa415d11af75 /mllib/src/test | |
parent | 2a9fe4a4e7acbe4c9d3b6c6e61ff46d1472ee5f4 (diff) | |
download | spark-a200e64561c8803731578267df16906f6773cbea.tar.gz spark-a200e64561c8803731578267df16906f6773cbea.tar.bz2 spark-a200e64561c8803731578267df16906f6773cbea.zip |
[SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load
jkbradley MechCoder
Resolves blocking issue for SPARK-6793. Please review after #7705 is merged.
Author: Feynman Liang <fliang@databricks.com>
Closes #7757 from feynmanliang/SPARK-9940-localSaveLoad and squashes the following commits:
d0d8cf4 [Feynman Liang] Fix thisClassName
0f30109 [Feynman Liang] Fix tests after changing LDAModel public API
dc61981 [Feynman Liang] Add hyperparams to LocalLDAModel save/load
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index aa36336ebb..b91c7cefed 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { test("model save/load") { // Test for LocalLDAModel. val localModel = new LocalLDAModel(tinyTopics, - Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D) + Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D) val tempDir1 = Utils.createTempDir() val path1 = tempDir1.toURI.toString @@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(samelocalModel.topicsMatrix === localModel.topicsMatrix) assert(samelocalModel.k === localModel.k) assert(samelocalModel.vocabSize === localModel.vocabSize) + assert(samelocalModel.docConcentration === localModel.docConcentration) + assert(samelocalModel.topicConcentration === localModel.topicConcentration) + assert(samelocalModel.gammaShape === localModel.gammaShape) val sameDistributedModel = DistributedLDAModel.load(sc, path2) assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix) @@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes) assert(distributedModel.docConcentration === sameDistributedModel.docConcentration) assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration) + assert(distributedModel.gammaShape === sameDistributedModel.gammaShape) assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals) val graph = distributedModel.graph |