diff options
author | Xusen Yin <yinxusen@gmail.com> | 2015-03-11 00:24:55 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-11 00:24:55 -0700 |
commit | 2d4e00efe2cf179935ae108a68f28edf6e5a1628 (patch) | |
tree | 2980f1d4c8467c4204ddc519250bd007c8ebcdd7 /mllib/src/test | |
parent | 2672374110d58e45ffae2408e74b96613deddda3 (diff) | |
download | spark-2d4e00efe2cf179935ae108a68f28edf6e5a1628.tar.gz spark-2d4e00efe2cf179935ae108a68f28edf6e5a1628.tar.bz2 spark-2d4e00efe2cf179935ae108a68f28edf6e5a1628.zip |
[SPARK-5986][MLLib] Add save/load for k-means
This PR adds save/load for K-means as described in SPARK-5986. Python version will be added in another PR.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #4951 from yinxusen/SPARK-5986 and squashes the following commits:
6dd74a0 [Xusen Yin] rewrite some functions and classes
cd390fd [Xusen Yin] add indexed point
b144216 [Xusen Yin] remove invalid comments
dce7055 [Xusen Yin] add save/load for k-means for SPARK-5986
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala | 44 |
1 files changed, 43 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala index caee591700..7bf250eb5a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala @@ -21,9 +21,10 @@ import scala.util.Random import org.scalatest.FunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class KMeansSuite extends FunSuite with MLlibTestSparkContext { @@ -257,6 +258,47 @@ class KMeansSuite extends FunSuite with MLlibTestSparkContext { assert(predicts(0) != predicts(3)) } } + + test("model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + Array(true, false).foreach { case selector => + val model = KMeansSuite.createModel(10, 3, selector) + // Save model, load it back, and compare. + try { + model.save(sc, path) + val sameModel = KMeansModel.load(sc, path) + KMeansSuite.checkEqual(model, sameModel) + } finally { + Utils.deleteRecursively(tempDir) + } + } + } +} + +object KMeansSuite extends FunSuite { + def createModel(dim: Int, k: Int, isSparse: Boolean): KMeansModel = { + val singlePoint = isSparse match { + case true => + Vectors.sparse(dim, Array.empty[Int], Array.empty[Double]) + case _ => + Vectors.dense(Array.fill[Double](dim)(0.0)) + } + new KMeansModel(Array.fill[Vector](k)(singlePoint)) + } + + def checkEqual(a: KMeansModel, b: KMeansModel): Unit = { + assert(a.k === b.k) + a.clusterCenters.zip(b.clusterCenters).foreach { + case (ca: SparseVector, cb: SparseVector) => + assert(ca === cb) + case (ca: DenseVector, cb: DenseVector) => + assert(ca === cb) + case _ => + throw new AssertionError("checkEqual failed since the two clusters were not identical.\n") + } + } } class KMeansClusterSuite extends FunSuite with LocalClusterSparkContext { |