aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-13 13:23:10 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-13 13:23:10 -0700
commita91aaf5a8cca18811c0cccc20f4e77f36231b344 (patch)
tree52ba43674f88ad8f5eb17021e0b15fe75696d4c2 /mllib/src/test
parent0d17593b32c12c3e39575430aa85cf20e56fae6a (diff)
downloadspark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.tar.gz
spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.tar.bz2
spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.zip
[SPARK-14375][ML] Unit test for spark.ml KMeansSummary
## What changes were proposed in this pull request? * Modify ```KMeansSummary.clusterSizes``` method to make it robust to empty clusters. * Add unit test for spark.ml ```KMeansSummary```. * Add Since tag. ## How was this patch tested? unit tests. cc jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #12254 from yanboliang/spark-14375.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala18
1 files changed, 17 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2076c745e2..2ca386e422 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}
- test("fit & transform") {
+ test("fit, transform, and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: KMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {