diff options
author | Feynman Liang <fliang@databricks.com> | 2015-07-30 14:08:59 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-30 14:08:59 -0700 |
commit | 89cda69ecd5ef942a68ad13fc4e1f4184010f087 (patch) | |
tree | f17f700ce4b42dc4b683e67e777d62407bc66abe | |
parent | 1abf7dc16ca1ba1777fe874c8b81fe6f2b0a6de5 (diff) | |
download | spark-89cda69ecd5ef942a68ad13fc4e1f4184010f087.tar.gz spark-89cda69ecd5ef942a68ad13fc4e1f4184010f087.tar.bz2 spark-89cda69ecd5ef942a68ad13fc4e1f4184010f087.zip |
[SPARK-9454] Change LDASuite tests to use vector comparisons
jkbradley Changes the current hacky string-comparison for vector compares.
Author: Feynman Liang <fliang@databricks.com>
Closes #7775 from feynmanliang/SPARK-9454-ldasuite-vector-compare and squashes the following commits:
bd91a82 [Feynman Liang] Remove println
905c76e [Feynman Liang] Fix string compare in distributed EM
2f24c13 [Feynman Liang] Improve LDASuite tests
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala | 33 |
1 files changed, 14 insertions, 19 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index d74482d3a7..c43e1e575c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -83,21 +83,14 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.topicsMatrix === localModel.topicsMatrix) // Check: topic summaries - // The odd decimal formatting and sorting is a hack to do a robust comparison. - val roundedTopicSummary = model.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - val roundedLocalTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => - // cut values to 3 digits after the decimal place - terms.zip(termWeights).map { case (term, weight) => - ("%.3f".format(weight).toDouble, term.toInt) - } - }.sortBy(_.mkString("")) - roundedTopicSummary.zip(roundedLocalTopicSummary).foreach { case (t1, t2) => - assert(t1 === t2) + val topicSummary = model.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + val localTopicSummary = localModel.describeTopics().map { case (terms, termWeights) => + Vectors.sparse(tinyVocabSize, terms, termWeights) + }.sortBy(_.toString) + topicSummary.zip(localTopicSummary).foreach { case (topics, topicsLocal) => + assert(topics ~== topicsLocal absTol 0.01) } // Check: per-doc topic distributions @@ -197,10 +190,12 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext { // verify the result, Note this generate the identical result as // [[https://github.com/Blei-Lab/onlineldavb]] - val topic1 = op.getLambda(0, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - val topic2 = op.getLambda(1, ::).inner.toArray.map("%.4f".format(_)).mkString(", ") - assert("1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950" == topic1) - assert("0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050" == topic2) + val topic1: Vector = Vectors.fromBreeze(op.getLambda(0, ::).t) + val topic2: Vector = Vectors.fromBreeze(op.getLambda(1, ::).t) + val expectedTopic1 = Vectors.dense(1.1101, 1.2076, 1.3050, 0.8899, 0.7924, 0.6950) + val expectedTopic2 = Vectors.dense(0.8899, 0.7924, 0.6950, 1.1101, 1.2076, 1.3050) + assert(topic1 ~== expectedTopic1 absTol 0.01) + assert(topic2 ~== expectedTopic2 absTol 0.01) } test("OnlineLDAOptimizer with toy data") { |