aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala22
1 files changed, 19 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c684bc11cc..2ca386e422 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -22,14 +22,14 @@ import org.apache.spark.ml.util.DefaultReadWriteTest
import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{DataFrame, Dataset, SQLContext}
private[clustering] case class TestRow(features: Vector)
class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
final val k = 5
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}
- test("fit & transform") {
+ test("fit, transform, and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: KMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {