aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala18
1 files changed, 17 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2076c745e2..2ca386e422 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}
- test("fit & transform") {
+ test("fit, transform, and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: KMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {