diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-13 13:23:10 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-13 13:23:10 -0700 |
commit | a91aaf5a8cca18811c0cccc20f4e77f36231b344 (patch) | |
tree | 52ba43674f88ad8f5eb17021e0b15fe75696d4c2 /mllib/src/test | |
parent | 0d17593b32c12c3e39575430aa85cf20e56fae6a (diff) | |
download | spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.tar.gz spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.tar.bz2 spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.zip |
[SPARK-14375][ML] Unit test for spark.ml KMeansSummary
## What changes were proposed in this pull request?
* Modify ```KMeansSummary.clusterSizes``` method to make it robust to empty clusters.
* Add unit test for spark.ml ```KMeansSummary```.
* Add Since tag.
## How was this patch tested?
unit tests.
cc jkbradley
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12254 from yanboliang/spark-14375.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2076c745e2..2ca386e422 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit & transform") { + test("fit, transform, and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) @@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: KMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } test("read/write") { |