author     Yuhao Yang <hhbyyh@gmail.com>              2015-10-17 10:04:19 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2015-10-17 10:04:19 -0700
commit     e1e77b22b3b577909a12c3aa898eb53be02267fd (patch)
tree       eb1b5f730b6833dbfbc7f8c67b0d65801a28ee97 /mllib/src
parent     8ac71d62d976bbfd0159cac6816dd8fa580ae1cb (diff)
[SPARK-11029] [ML] Add computeCost to KMeansModel in spark.ml
jira: https://issues.apache.org/jira/browse/SPARK-11029

We should add a method analogous to spark.mllib.clustering.KMeansModel.computeCost
to spark.ml.clustering.KMeansModel. This will be a temp fix until we have proper
evaluators defined for clustering.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: yuhaoyang <yuhao@zhanglipings-iMac.local>

Closes #9073 from hhbyyh/computeCost.
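For context, the quantity computeCost reports is the within-set sum of squared errors (WSSSE). The sketch below is a minimal standalone illustration of that computation, assuming plain Array[Double] points and centers rather than Spark's Vector type; the actual method delegates to spark.mllib's KMeansModel.computeCost over an RDD[Vector].

// Squared Euclidean distance between two points of equal dimension.
def squaredDistance(a: Array[Double], b: Array[Double]): Double =
  a.zip(b).map { case (x, y) => (x - y) * (x - y) }.sum

// WSSSE: each point contributes its squared distance to the nearest center.
def wssse(points: Seq[Array[Double]], centers: Seq[Array[Double]]): Double =
  points.map(p => centers.map(c => squaredDistance(p, c)).min).sum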
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala       12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala   1
2 files changed, 13 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index f40ab71fb2..509be63002 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -117,6 +117,18 @@ class KMeansModel private[ml] (
   @Since("1.5.0")
   def clusterCenters: Array[Vector] = parentModel.clusterCenters
+
+  /**
+   * Return the K-means cost (sum of squared distances of points to their nearest center) for this
+   * model on the given data.
+   */
+  // TODO: Replace the temp fix when we have proper evaluators defined for clustering.
+  @Since("1.6.0")
+  def computeCost(dataset: DataFrame): Double = {
+    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+    val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+    parentModel.computeCost(data)
+  }
 }

 /**
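A hypothetical usage sketch of the new method, assuming a DataFrame named dataset with a Vector-typed "features" column (the default value of featuresCol); the variable names and parameter values here are illustrative, not part of the patch.

import org.apache.spark.ml.clustering.KMeans

// Fit a model, then score the same data with the new method.
val kmeans = new KMeans().setK(5).setSeed(1L)
val model = kmeans.fit(dataset)
val cost = model.computeCost(dataset)  // WSSSE; lower means tighter clusters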
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 688b0e31f9..c05f90550d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -104,5 +104,6 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
     val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
     assert(clusters.size === k)
     assert(clusters === Set(0, 1, 2, 3, 4))
+    assert(model.computeCost(dataset) < 0.1)
   }
 }
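The < 0.1 bound suggests the suite's synthetic data concentrates points at (or extremely close to) the k cluster centers, so the expected cost is near zero. Reusing the wssse helper sketched above, a degenerate example with made-up points makes that property explicit.

// Points that coincide exactly with their centers give zero cost,
// which trivially satisfies the < 0.1 bound asserted in the suite.
val centers = Seq(Array(0.0, 0.0), Array(1.0, 1.0))
val points = centers ++ centers
assert(wssse(points, centers) == 0.0)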