diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index b719a8c7e7..18f2c994b4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame -class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { +class BisectingKMeansSuite + extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { final val k = 5 @transient var dataset: DataFrame = _ @@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) } + + test("read/write") { + def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = { + assert(model.clusterCenters === model2.clusterCenters) + } + val bisectingKMeans = new BisectingKMeans() + testEstimatorAndModelReadWrite( + bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + } +} + +object BisectingKMeansSuite { + val allParamSettings: Map[String, Any] = Map( + "k" -> 3, + "maxIter" -> 2, + "seed" -> -1L, + "minDivisibleClusterSize" -> 2.0 + ) } |