aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-24 10:56:48 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-24 10:56:48 -0700
commite25312451322969ad716dddf8248b8c17f68323b (patch)
tree795efc0f5690242c12b4867151ba72182abcb601 /mllib
parentc2b50d693e469558e3b3c9cbb9d76089d259b587 (diff)
downloadspark-e25312451322969ad716dddf8248b8c17f68323b.tar.gz
spark-e25312451322969ad716dddf8248b8c17f68323b.tar.bz2
spark-e25312451322969ad716dddf8248b8c17f68323b.zip
[SPARK-9222] [MLlib] Make class instantiation variables in DistributedLDAModel private[clustering]
This makes it easier to test all the class variables of the DistributedLDAModel. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7573 from MechCoder/lda_test and squashes the following commits: 2f1a293 [MechCoder] [SPARK-9222] [MLlib] Make class instantiation variables in DistributedLDAModel private[clustering]
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala15
2 files changed, 19 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index 920b57756b..31c1d520fd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -283,12 +283,12 @@ object LocalLDAModel extends Loader[LocalLDAModel] {
*/
@Experimental
class DistributedLDAModel private (
- private val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
- private val globalTopicTotals: LDA.TopicCounts,
+ private[clustering] val graph: Graph[LDA.TopicCounts, LDA.TokenCount],
+ private[clustering] val globalTopicTotals: LDA.TopicCounts,
val k: Int,
val vocabSize: Int,
- private val docConcentration: Double,
- private val topicConcentration: Double,
+ private[clustering] val docConcentration: Double,
+ private[clustering] val topicConcentration: Double,
private[spark] val iterationTimes: Array[Double]) extends LDAModel {
import LDA._
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index da70d9bd7c..376a87f051 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.clustering
import breeze.linalg.{DenseMatrix => BDM}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.graphx.Edge
import org.apache.spark.mllib.linalg.{DenseMatrix, Matrix, Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -318,6 +319,20 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(distributedModel.k === sameDistributedModel.k)
assert(distributedModel.vocabSize === sameDistributedModel.vocabSize)
assert(distributedModel.iterationTimes === sameDistributedModel.iterationTimes)
+ assert(distributedModel.docConcentration === sameDistributedModel.docConcentration)
+ assert(distributedModel.topicConcentration === sameDistributedModel.topicConcentration)
+ assert(distributedModel.globalTopicTotals === sameDistributedModel.globalTopicTotals)
+
+ val graph = distributedModel.graph
+ val sameGraph = sameDistributedModel.graph
+ assert(graph.vertices.sortByKey().collect() === sameGraph.vertices.sortByKey().collect())
+ val edge = graph.edges.map {
+ case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos)
+ }.sortBy(x => (x._1, x._2)).collect()
+ val sameEdge = sameGraph.edges.map {
+ case Edge(sid: Long, did: Long, nos: Double) => (sid, did, nos)
+ }.sortBy(x => (x._1, x._2)).collect()
+ assert(edge === sameEdge)
} finally {
Utils.deleteRecursively(tempDir1)
Utils.deleteRecursively(tempDir2)