aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorFeynman Liang <fliang@databricks.com>2015-07-30 14:08:59 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-30 14:08:59 -0700
commit89cda69ecd5ef942a68ad13fc4e1f4184010f087 (patch)
treef17f700ce4b42dc4b683e67e777d62407bc66abe /mllib
parent1abf7dc16ca1ba1777fe874c8b81fe6f2b0a6de5 (diff)
downloadspark-89cda69ecd5ef942a68ad13fc4e1f4184010f087.tar.gz
spark-89cda69ecd5ef942a68ad13fc4e1f4184010f087.tar.bz2
spark-89cda69ecd5ef942a68ad13fc4e1f4184010f087.zip
[SPARK-9454] Change LDASuite tests to use vector comparisons
jkbradley Changes the current hacky string-comparison for vector compares. Author: Feynman Liang <fliang@databricks.com> Closes #7775 from feynmanliang/SPARK-9454-ldasuite-vector-compare and squashes the following commits: bd91a82 [Feynman Liang] Remove println 905c76e [Feynman Liang] Fix string compare in distributed EM 2f24c13 [Feynman Liang] Improve LDASuite tests
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala33
1 file changed, 14 insertions, 19 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
index d74482d3a7..c43e1e575c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala
@@ -83,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.topicsMatrix === localModel.topicsMatrix)
// Check: topic summaries
- // The odd decimal formatting and sorting is a hack to do a robust comparison.
- val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) =>
- // cut values to 3 digits after the decimal place
- terms.zip(termWeights).map { case (term, weight) =>
- ("%.3f".format(weight).toDouble, term.toInt)
- }
- }.sortBy(_.mkString(""))
- val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
- // cut values to 3 digits after the decimal place
- terms.zip(termWeights).map { case (term, weight) =>
- ("%.3f".format(weight).toDouble, term.toInt)
- }
- }.sortBy(_.mkString(""))
- roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) =>
- assert(t1 === t2)
+ val topicSummary = model.describeTopics().map { case (terms, termWeights) =>
+ Vectors.sparse(tinyVocabSize, terms, termWeights)
+ }.sortBy(_.toString)
+ val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) =>
+ Vectors.sparse(tinyVocabSize, terms, termWeights)
+ }.sortBy(_.toString)
+ topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) =>
+ assert(topics ~== topicsLocal absTol 0.01)
}
// Check: per-doc topic distributions
@@ -197,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
// verify the result, Note this generate the identical result as
// [[https://github.com/Blei-Lab/onlineldavb]]
- val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
- val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ")
- assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1)
- assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2)
+ val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t)
+ val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t)
+ val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950)
+ val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050)
+ assert(topic1 ~== expectedTopic1 absTol 0.01)
+ assert(topic2 ~== expectedTopic2 absTol 0.01)
}
test("OnlineLDAOptimizer with toy data") {