diff options
author | Yuhao Yang <hhbyyh@gmail.com> | 2016-03-31 11:12:40 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-31 11:12:40 -0700 |
commit | a0a1991580ed24230f88cae9f5a4dfbe58f03b28 (patch) | |
tree | d0d230f6116b8b6cd1bff0da73b1a782edb04f68 /mllib/src/test/scala | |
parent | 3b3cc76004438a942ecea752db39f3a904a52462 (diff) | |
download | spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.gz spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.tar.bz2 spark-a0a1991580ed24230f88cae9f5a4dfbe58f03b28.zip |
[SPARK-13782][ML] Model export/import for spark.ml: BisectingKMeans
## What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-13782
Model export/import for BisectingKMeans in spark.ml and mllib
## How was this patch tested?
unit tests
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #11933 from hhbyyh/bisectingsave.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala | 22 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala | 18 |
2 files changed, 39 insertions, 1 deletions
```diff
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index b719a8c7e7..18f2c994b4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -18,10 +18,12 @@
 package org.apache.spark.ml.clustering
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
 
-class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+class BisectingKMeansSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   final val k = 5
   @transient var dataset: DataFrame = _
@@ -84,4 +86,22 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(model.computeCost(dataset) < 0.1)
     assert(model.hasParent)
   }
+
+  test("read/write") {
+    def checkModelData(model: BisectingKMeansModel, model2: BisectingKMeansModel): Unit = {
+      assert(model.clusterCenters === model2.clusterCenters)
+    }
+    val bisectingKMeans = new BisectingKMeans()
+    testEstimatorAndModelReadWrite(
+      bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData)
+  }
+}
+
+object BisectingKMeansSuite {
+  val allParamSettings: Map[String, Any] = Map(
+    "k" -> 3,
+    "maxIter" -> 2,
+    "seed" -> -1L,
+    "minDivisibleClusterSize" -> 2.0
+  )
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
index 41b9d5c0d9..35f7932ae8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
 
 class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -179,4 +180,21 @@ class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
       }
     }
   }
+
+  test("BisectingKMeans model save/load") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    val points = (1 until 8).map(i => Vectors.dense(i))
+    val data = sc.parallelize(points, 2)
+    val model = new BisectingKMeans().run(data)
+    try {
+      model.save(sc, path)
+      val sameModel = BisectingKMeansModel.load(sc, path)
+      assert(model.k === sameModel.k)
+      model.clusterCenters.zip(sameModel.clusterCenters).foreach(c => c._1 === c._2)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
```