diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 22 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala | 18 |
2 files changed, 39 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index b719a8c7e7..18f2c994b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame -class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class BisectingKMeansSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) } + + test("read/write") { + def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val bisectingKMeans = new BisectingKMeans() + testEstimatorAndModelReadWrite( + bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + } +} + +object BisectingKMeansSuite { + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "seed" -> -1L, + "minDivisibleClusterSize" -> 2.0 + ) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala index 41b9d5c0d9..35f7932ae8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.util.Utils class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { } } } + + test("BisectingKMeans model save/load") { + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + + val points = (1 until 8).map(i => Vectors.dense(i)) + val data = sc.parallelize(points, 2) + val model = new BisectingKMeans().run(data) + try { + model.save(sc, path) + val sameModel = BisectingKMeansModel.load(sc, path) + assert(model.k === sameModel.k) + model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2) + } finally { + Utils.deleteRecursively(tempDir) + } + } } |