Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 22
1 file changed, 21 insertions(+), 1 deletion(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b719a8c7e7..18f2c994b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,10 +18,12 @@
 package org.apache.spark.ml.clustering

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame

-class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BisectingKMeansSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {

   final val k = 5
   @transient var dataset: DataFrame = _
@@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
   }
+
+  test("read/write") {
+    def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = {
+      assert(model.clusterCenters === model2.clusterCenters)
+    }
+    val bisectingKMeans = new BisectingKMeans()
+    testEstimatorAndModelReadWrite(
+      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object BisectingKMeansSuite {
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3,
+    "maxIter" -> 2,
+    "seed" -> -1L,
+    "minDivisibleClusterSize" -> 2.0
+  )
 }
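
For context, the added "read/write" test exercises ML persistence on BisectingKMeansModel via the DefaultReadWriteTest helper. Below is a minimal standalone sketch of the same save/load round trip, not taken from the patch: it assumes Spark 2.0+, an active SparkSession bound to a val named spark, and a hypothetical scratch path; the param values mirror BisectingKMeansSuite.allParamSettings.

// Minimal sketch of the save/load round trip covered by the new test.
// Assumes Spark 2.0+, a SparkSession named `spark`, and write access to /tmp.
import org.apache.spark.ml.clustering.{BisectingKMeans, BisectingKMeansModel}
import org.apache.spark.ml.linalg.Vectors
import spark.implicits._

// Tiny toy dataset with the default "features" column.
val dataset = Seq(
  Vectors.dense(0.0, 0.0), Vectors.dense(0.1, 0.1),
  Vectors.dense(9.0, 9.0), Vectors.dense(9.1, 9.1)
).map(Tuple1.apply).toDF("features")

// Same settings as BisectingKMeansSuite.allParamSettings.
val bkm = new BisectingKMeans()
  .setK(3)
  .setMaxIter(2)
  .setSeed(-1L)
  .setMinDivisibleClusterSize(2.0)

val model = bkm.fit(dataset)

// Write the fitted model and read it back (path is hypothetical).
val path = "/tmp/bisecting-kmeans-model"
model.write.overwrite().save(path)
val loaded = BisectingKMeansModel.load(path)

// The reloaded model should carry the same cluster centers,
// which is what the suite's checkModelData asserts.
assert(model.clusterCenters.sameElements(loaded.clusterCenters))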