author     Yanbo Liang <ybliang8@gmail.com>            2016-04-13 13:23:10 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2016-04-13 13:23:10 -0700
commit     a91aaf5a8cca18811c0cccc20f4e77f36231b344 (patch)
tree       52ba43674f88ad8f5eb17021e0b15fe75696d4c2 /mllib
parent     0d17593b32c12c3e39575430aa85cf20e56fae6a (diff)
[SPARK-14375][ML] Unit test for spark.ml KMeansSummary
## What changes were proposed in this pull request?

* Modify `KMeansSummary.clusterSizes` method to make it robust to empty clusters.
* Add unit test for spark.ml `KMeansSummary`.
* Add Since tag.

## How was this patch tested?

Unit tests.

cc jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12254 from yanboliang/spark-14375.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala      35
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala          2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala  18
3 files changed, 47 insertions, 8 deletions
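For orientation before the diff, here is a minimal usage sketch of the summary API this commit touches. It is not taken from the patch; the toy data, the `spark` session handle, and the "features" column name are illustrative assumptions, while `hasSummary`, `summary`, and `clusterSizes` come from the code changed below.

```scala
import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.linalg.Vectors

// Hypothetical toy data; assumes an existing SparkSession named `spark`.
val dataset = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(0.0, 0.0)),
  Tuple1(Vectors.dense(0.1, 0.1)),
  Tuple1(Vectors.dense(9.0, 9.0)),
  Tuple1(Vectors.dense(9.1, 9.1))
)).toDF("features")

val model = new KMeans().setK(2).setSeed(1L).fit(dataset)

// Summary accessors exercised by this patch.
if (model.hasSummary) {                        // hasSummary is added in this patch
  val summary = model.summary
  println(summary.clusterSizes.mkString(","))  // now Array[Long], one entry per cluster
}
```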
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index d716bc6887..b324196842 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -144,6 +144,12 @@ class KMeansModel private[ml] (
}
/**
+ * Return true if there exists summary of model.
+ */
+ @Since("2.0.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
* Gets summary of model on training set. An exception is
* thrown if `trainingSummary == None`.
*/
@@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") (
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
- val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+ val summary = new KMeansSummary(
+ model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
model.setSummary(summary)
}
@@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
override def load(path: String): KMeans = super.load(path)
}
+/**
+ * :: Experimental ::
+ * Summary of KMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.0.0")
+@Experimental
class KMeansSummary private[clustering] (
@Since("2.0.0") @transient val predictions: DataFrame,
@Since("2.0.0") val predictionCol: String,
- @Since("2.0.0") val featuresCol: String) extends Serializable {
+ @Since("2.0.0") val featuresCol: String,
+ @Since("2.0.0") val k: Int) extends Serializable {
/**
* Cluster centers of the transformed data.
@@ -296,11 +315,15 @@ class KMeansSummary private[clustering] (
@transient lazy val cluster: DataFrame = predictions.select(predictionCol)
/**
- * Size of each cluster.
+ * Size of (number of data points in) each cluster.
*/
@Since("2.0.0")
- lazy val clusterSizes: Array[Int] = cluster.rdd.map {
- case Row(clusterIdx: Int) => (clusterIdx, 1)
- }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
}
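The core of the `clusterSizes` change above is that the result is pre-allocated with one slot per cluster, so a cluster that receives no points reports 0 instead of being silently dropped, as the previous `reduceByKey`-based version would do. A small self-contained sketch of the two behaviors on hypothetical assignments (plain Scala collections, no Spark, k and the assignment sequence are made up for illustration):

```scala
// Hypothetical cluster assignments with k = 3; cluster 1 is empty.
val k = 3
val assignments = Seq(0, 0, 2, 2, 2)

// Old approach: only non-empty clusters contribute, so the array length varies.
val oldSizes: Array[Int] =
  assignments.groupBy(identity).toArray.sortBy(_._1).map(_._2.size)
// oldSizes == Array(2, 3)  -- length 2, the empty cluster 1 is missing

// New approach: fixed-length array indexed by cluster id, empty clusters stay 0.
val newSizes: Array[Long] = {
  val sizes = Array.fill[Long](k)(0L)
  assignments.groupBy(identity).foreach { case (c, pts) => sizes(c) = pts.size.toLong }
  sizes
}
// newSizes == Array(2, 0, 3)  -- length k, one entry per cluster
```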
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ee513579ce..9e2b81ee20 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -37,7 +37,7 @@ private[r] class KMeansWrapper private (
lazy val k: Int = kMeansModel.getK
- lazy val size: Array[Int] = kMeansModel.summary.clusterSizes
+ lazy val size: Array[Long] = kMeansModel.summary.clusterSizes
lazy val cluster: DataFrame = kMeansModel.summary.cluster
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2076c745e2..2ca386e422 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
}
}
- test("fit & transform") {
+ test("fit, transform, and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)
@@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: KMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {