aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorFeynman Liang <fliang@databricks.com>2015-07-29 19:02:15 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-29 19:02:15 -0700
commita200e64561c8803731578267df16906f6773cbea (patch)
tree838571920db0948b88e385c8ed1faa415d11af75 /mllib/src/test
parent2a9fe4a4e7acbe4c9d3b6c6e61ff46d1472ee5f4 (diff)
downloadspark-a200e64561c8803731578267df16906f6773cbea.tar.gz
spark-a200e64561c8803731578267df16906f6773cbea.tar.bz2
spark-a200e64561c8803731578267df16906f6773cbea.zip
[SPARK-9440] [MLLIB] Add hyperparameters to LocalLDAModel save/load
jkbradley MechCoder Resolves blocking issue for SPARK-6793. Please review after #7705 is merged. Author: Feynman Liang <fliang@databricks.com> Closes #7757 from feynmanliang/SPARK-9940-localSaveLoad and squashes the following commits: d0d8cf4 [Feynman Liang] Fix thisClassName 0f30109 [Feynman Liang] Fix tests after changing LDAModel public API dc61981 [Feynman Liang] Add hyperparams to LocalLDAModel save/load
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala6
1 files changed, 5 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index aa36336ebb..b91c7cefed 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -334,7 +334,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
test("model save/load") {
// Test for LocalLDAModel.
val localModel = new LocalLDAModel(tinyTopics,
- Vectors.dense(Array.fill(tinyTopics.numRows)(1.0 / tinyTopics.numRows)), 1D, 100D)
+ Vectors.dense(Array.fill(tinyTopics.numRows)(0.01)), 0.5D, 10D)
val tempDir1 = Utils.createTempDir()
val path1 = tempDir1.toURI.toString
@@ -360,6 +360,9 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(samelocalModel.topicsMatrix === localModel.topicsMatrix)
assert(samelocalModel.k === localModel.k)
assert(samelocalModel.vocabSize === localModel.vocabSize)
+ assert(samelocalModel.docConcentration === localModel.docConcentration)
+ assert(samelocalModel.topicConcentration === localModel.topicConcentration)
+ assert(samelocalModel.gammaShape === localModel.gammaShape)
val sameDistributedModel = DistributedLDAModel.load(sc, path2)
assert(distributedModel.topicsMatrix === sameDistributedModel.topicsMatrix)
@@ -368,6 +371,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
+ assert(distributedModel.gammaShape === sameDistributedModel.gammaShape)
assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
val graph = distributedModel.graph