aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author Xusen Yin <yinxusen@gmail.com> 2015-04-13 11:53:17 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-04-13 11:53:17 -0700
commit 1e340c3ae4d5361d048a3d6990f144cfc923666f (patch)
tree 81c481ecc9b909c252dc8f57e540b1839f193ebd /mllib/src/test
parent 6cc5b3ed3c0c729f97956fa017d8eb7d6b43f90f (diff)
download spark-1e340c3ae4d5361d048a3d6990f144cfc923666f.tar.gz
spark-1e340c3ae4d5361d048a3d6990f144cfc923666f.tar.bz2
spark-1e340c3ae4d5361d048a3d6990f144cfc923666f.zip
[SPARK-5988][MLlib] add save/load for PowerIterationClusteringModel
See JIRA issue [SPARK-5988](https://issues.apache.org/jira/browse/SPARK-5988). Author: Xusen Yin <yinxusen@gmail.com> Closes #5450 from yinxusen/SPARK-5988 and squashes the following commits: cb1ecfa [Xusen Yin] change Assignment into case class b1dd24c [Xusen Yin] add test suite 63c3923 [Xusen Yin] add save load for power iteration clustering
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala 34
1 files changed, 34 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 6315c03a70..6d6fe6fe46 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -18,12 +18,15 @@
package org.apache.spark.mllib.clustering
import scala.collection.mutable
+import scala.util.Random
import org.scalatest.FunSuite
+import org.apache.spark.SparkContext
import org.apache.spark.graphx.{Edge, Graph}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
@@ -110,4 +113,35 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
assert(x ~== u1(i.toInt) absTol 1e-14)
}
}
+
+  test("model save/load") {
+    // Round-trip a small model (k = 3, 10 points) through save/load and compare
+    // the reloaded copy against the original.
+    val scratchDir = Utils.createTempDir()
+    val uri = scratchDir.toURI.toString
+    val original = PowerIterationClusteringSuite.createModel(sc, 3, 10)
+    try {
+      original.save(sc, uri)
+      val reloaded = PowerIterationClusteringModel.load(sc, uri)
+      PowerIterationClusteringSuite.checkEqual(original, reloaded)
+    } finally {
+      // Remove the scratch directory whether or not the checks above succeeded.
+      Utils.deleteRecursively(scratchDir)
+    }
+  }
+}
+
+object PowerIterationClusteringSuite extends FunSuite {
+
+  /**
+   * Builds a model with `nPoints` assignments for use in tests.
+   *
+   * Point ids are `0 until nPoints`; each point's cluster is drawn with
+   * `Random.nextInt(k)`. The assignments are generated as a local collection
+   * and then distributed with `sc.parallelize`.
+   *
+   * @param sc      context used to parallelize the assignments
+   * @param k       number of clusters; also the exclusive upper bound of cluster ids
+   * @param nPoints number of assignments to generate
+   * @return a model wrapping `k` and the generated assignments
+   */
+  def createModel(sc: SparkContext, k: Int, nPoints: Int): PowerIterationClusteringModel = {
+    val assignments = sc.parallelize(
+      (0 until nPoints).map(p => PowerIterationClustering.Assignment(p, Random.nextInt(k))))
+    new PowerIterationClusteringModel(k, assignments)
+  }
+
+  /**
+   * Asserts that two models have the same `k` and exactly the same
+   * (id, cluster) assignments.
+   *
+   * @param a first model to compare
+   * @param b second model to compare
+   */
+  def checkEqual(a: PowerIterationClusteringModel, b: PowerIterationClusteringModel): Unit = {
+    assert(a.k === b.k)
+
+    // Compare the complete sorted contents instead of filtering an inner join:
+    // a join drops ids that are missing from either side, so a model with empty
+    // or truncated assignments would otherwise pass vacuously. Collecting is
+    // acceptable here because this helper is used on small test models.
+    val aAssignments = a.assignments.map(x => (x.id, x.cluster)).collect().sorted.toSeq
+    val bAssignments = b.assignments.map(x => (x.id, x.cluster)).collect().sorted.toSeq
+    assert(aAssignments === bAssignments)
+  }
+}