aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Yuhao Yang <hhbyyh@gmail.com> 2016-03-31 11:12:40 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-03-31 11:12:40 -0700
commit: a0a1991580ed24230f88cae9f5a4dfbe58f03b28 (patch)
tree: d0d230f6116b8b6cd1bff0da73b1a782edb04f68 /mllib/src/test
parent: 3b3cc76004438a942ecea752db39f3a904a52462 (diff)
download: spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.gz
spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.bz2
spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.zip
[SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans
## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-13782

Model export/import for BisectingKMeans in spark.ml and mllib

## How was this patch tested?

unit tests

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #11933 from hhbyyh/bisectingsave.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 22
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala | 18
2 files changed, 39 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b719a8c7e7..18f2c994b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,10 +18,12 @@
package org.apache.spark.ml.clustering
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
-class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BisectingKMeansSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
@transient var dataset: DataFrame = _
@@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
}
+
+ test("read/write") {
+   // Model-level check handed to the shared read/write helper: the two models
+   // it is called with must have equal cluster centers.
+   def assertSameCenters(lhs: BisectingKMeansModel, rhs: BisectingKMeansModel): Unit =
+     assert(lhs.clusterCenters === rhs.clusterCenters)
+
+   testEstimatorAndModelReadWrite(
+     new BisectingKMeans(),
+     dataset,
+     BisectingKMeansSuite.allParamSettings,
+     assertSameCenters)
+ }
+}
+
+object BisectingKMeansSuite {
+ /**
+  * Param name -> value settings that the "read/write" test passes to
+  * `testEstimatorAndModelReadWrite`.
+  */
+ val allParamSettings: Map[String, Any] = Map(
+   ("k", 3),
+   ("maxIter", 2),
+   ("seed", -1L),
+   ("minDivisibleClusterSize", 2.0)
+ )
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
index 41b9d5c0d9..35f7932ae8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
}
+
+ test("BisectingKMeans model save/load") {
+   // Train a small model on seven 1-dimensional points (1.0 .. 7.0) in 2 partitions.
+   val points = (1 until 8).map(i => Vectors.dense(i))
+   val data = sc.parallelize(points, 2)
+   val model = new BisectingKMeans().run(data)
+
+   // Create the temp dir right before the try so the finally block always cleans it up.
+   val tempDir = Utils.createTempDir()
+   try {
+     val path = tempDir.toURI.toString
+     model.save(sc, path)
+     val sameModel = BisectingKMeansModel.load(sc, path)
+     assert(model.k === sameModel.k)
+     // Each pair must be wrapped in assert: a bare `===` only yields a Boolean,
+     // which foreach discards, so a mismatch would go unnoticed.
+     model.clusterCenters.zip(sameModel.clusterCenters).foreach { case (expected, actual) =>
+       assert(expected === actual)
+     }
+   } finally {
+     Utils.deleteRecursively(tempDir)
+   }
+ }
}