Diffstat (limited to 'mllib')
 mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala      | 12 ++++++++++++
 mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala |  1 +
 2 files changed, 13 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index f40ab71fb2..509be63002 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -117,6 +117,18 @@ class KMeansModel private[ml] (
@Since("1.5.0")
def clusterCenters: Array[Vector] = parentModel.clusterCenters
+
+ /**
+ * Return the K-means cost (sum of squared distances of points to their nearest center) for this
+ * model on the given data.
+ */
+ // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
+ @Since("1.6.0")
+ def computeCost(dataset: DataFrame): Double = {
+ SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+ val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+ parentModel.computeCost(data)
+ }
}
/**
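
For orientation, a minimal sketch of how the new method could be called once this patch lands. The dataset, the "features" column name, and the variable names below are illustrative assumptions, not part of the patch:

    // Fit a k-means model on a DataFrame with a "features" vector column,
    // then evaluate the sum of squared distances of points to their centers.
    import org.apache.spark.ml.clustering.KMeans

    val kmeans = new KMeans().setK(5).setFeaturesCol("features")
    val model = kmeans.fit(dataset)          // dataset: DataFrame with a vector column
    val cost = model.computeCost(dataset)    // the quantity documented above
    println(s"k-means cost = $cost")
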
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 688b0e31f9..c05f90550d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -104,5 +104,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
     assert(clusters.size === k)
     assert(clusters === Set(0, 1, 2, 3, 4))
+    assert(model.computeCost(dataset) < 0.1)
   }
 }
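
The assertion added above bounds the k-means objective that computeCost delegates to parentModel.computeCost: the sum, over all points, of the squared distance to the nearest center. A self-contained sketch of that objective over plain MLlib vectors, for reference (the helper name kMeansCost is hypothetical; Vectors.sqdist is MLlib's squared-distance utility):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    // Sum over all points of the squared distance to the closest center.
    def kMeansCost(points: Seq[Vector], centers: Seq[Vector]): Double =
      points.map(p => centers.map(c => Vectors.sqdist(p, c)).min).sum
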