aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2015-03-11 00:24:55 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-11 00:24:55 -0700
commit2d4e00efe2cf179935ae108a68f28edf6e5a1628 (patch)
tree2980f1d4c8467c4204ddc519250bd007c8ebcdd7 /mllib/src/test
parent2672374110d58e45ffae2408e74b96613deddda3 (diff)
downloadspark-2d4e00efe2cf179935ae108a68f28edf6e5a1628.tar.gz
spark-2d4e00efe2cf179935ae108a68f28edf6e5a1628.tar.bz2
spark-2d4e00efe2cf179935ae108a68f28edf6e5a1628.zip
[SPARK-5986][MLLib] Add save/load for k-means
This PR adds save/load for K-means as described in SPARK-5986. Python version will be added in another PR. Author: Xusen Yin <yinxusen@gmail.com> Closes #4951 from yinxusen/SPARK-5986 and squashes the following commits: 6dd74a0 [Xusen Yin] rewrite some functions and classes cd390fd [Xusen Yin] add indexed point b144216 [Xusen Yin] remove invalid comments dce7055 [Xusen Yin] add save/load for k-means for SPARK-5986
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala44
1 files changed, 43 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
index caee591700..7bf250eb5a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala
@@ -21,9 +21,10 @@ import scala.util.Random
import org.scalatest.FunSuite
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class KMeansSuite extends FunSuite with MLlibTestSparkContext {
@@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext {
assert(predicts(0) != predicts(3))
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(true, false).foreach { case selector =>
+ val model = KMeansSuite.createModel(10, 3, selector)
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = KMeansModel.load(sc, path)
+ KMeansSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+}
+
object KMeansSuite extends FunSuite {

  /**
   * Creates a [[KMeansModel]] whose `k` cluster centers are all the
   * zero vector of dimension `dim`.
   *
   * @param dim dimension of each cluster center
   * @param k number of cluster centers
   * @param isSparse if true, centers are stored as empty SparseVectors;
   *                 otherwise as all-zero DenseVectors
   */
  def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = {
    // Matching on a Boolean (`isSparse match { case true => ... }`) is
    // non-idiomatic Scala; a plain if/else expresses the same choice.
    val singlePoint: Vector =
      if (isSparse) {
        Vectors.sparse(dim, Array.empty[Int], Array.empty[Double])
      } else {
        Vectors.dense(Array.fill[Double](dim)(0.0))
      }
    new KMeansModel(Array.fill[Vector](k)(singlePoint))
  }

  /**
   * Asserts that two models are equal: same `k`, and pairwise-equal
   * cluster centers with matching representations (sparse vs. dense).
   * Fails the test if a pair mixes representations.
   */
  def checkEqual(a: KMeansModel, b: KMeansModel): Unit = {
    assert(a.k === b.k)
    a.clusterCenters.zip(b.clusterCenters).foreach {
      case (ca: SparseVector, cb: SparseVector) =>
        assert(ca === cb)
      case (ca: DenseVector, cb: DenseVector) =>
        assert(ca === cb)
      case _ =>
        // Mixed sparse/dense pair: the representations diverged on reload.
        throw new AssertionError("checkEqual failed since the two clusters were not identical.\n")
    }
  }
}
class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext {