diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-13 13:23:10 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-13 13:23:10 -0700 |
commit | a91aaf5a8cca18811c0cccc20f4e77f36231b344 (patch) | |
tree | 52ba43674f88ad8f5eb17021e0b15fe75696d4c2 /mllib | |
parent | 0d17593b32c12c3e39575430aa85cf20e56fae6a (diff) | |
download | spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.tar.gz spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.tar.bz2 spark-a91aaf5a8cca18811c0cccc20f4e77f36231b344.zip |
[SPARK-14375][ML] Unit test for spark.ml KMeansSummary
## What changes were proposed in this pull request?
* Modify ```KMeansSummary.clusterSizes``` method to make it robust to empty clusters.
* Add unit test for spark.ml ```KMeansSummary```.
* Add Since tag.
## How was this patch tested?
unit tests.
cc jkbradley
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12254 from yanboliang/spark-14375.
Diffstat (limited to 'mllib')
3 files changed, 47 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index d716bc6887..b324196842 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -144,6 +144,12 @@ class KMeansModel private[ml] ( } /** + * Return true if there exists summary of model. + */ + @Since("2.0.0") + def hasSummary: Boolean = trainingSummary.nonEmpty + + /** * Gets summary of model on training set. An exception is * thrown if `trainingSummary == None`. */ @@ -267,7 +273,8 @@ class KMeans @Since("1.5.0") ( .setEpsilon($(tol)) val parentModel = algo.run(rdd) val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) - val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol)) + val summary = new KMeansSummary( + model.transform(dataset), $(predictionCol), $(featuresCol), $(k)) model.setSummary(summary) } @@ -284,10 +291,22 @@ object KMeans extends DefaultParamsReadable[KMeans] { override def load(path: String): KMeans = super.load(path) } +/** + * :: Experimental :: + * Summary of KMeans. + * + * @param predictions [[DataFrame]] produced by [[KMeansModel.transform()]] + * @param predictionCol Name for column of predicted clusters in `predictions` + * @param featuresCol Name for column of features in `predictions` + * @param k Number of clusters + */ +@Since("2.0.0") +@Experimental class KMeansSummary private[clustering] ( @Since("2.0.0") @transient val predictions: DataFrame, @Since("2.0.0") val predictionCol: String, - @Since("2.0.0") val featuresCol: String) extends Serializable { + @Since("2.0.0") val featuresCol: String, + @Since("2.0.0") val k: Int) extends Serializable { /** * Cluster centers of the transformed data. @@ -296,11 +315,15 @@ class KMeansSummary private[clustering] ( @transient lazy val cluster: DataFrame = predictions.select(predictionCol) /** - * Size of each cluster. + * Size of (number of data points in) each cluster. */ @Since("2.0.0") - lazy val clusterSizes: Array[Int] = cluster.rdd.map { - case Row(clusterIdx: Int) => (clusterIdx, 1) - }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) + lazy val clusterSizes: Array[Long] = { + val sizes = Array.fill[Long](k)(0) + cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach { + case Row(cluster: Int, count: Long) => sizes(cluster) = count + } + sizes + } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala index ee513579ce..9e2b81ee20 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala @@ -37,7 +37,7 @@ private[r] class KMeansWrapper private ( lazy val k: Int = kMeansModel.getK - lazy val size: Array[Int] = kMeansModel.summary.clusterSizes + lazy val size: Array[Long] = kMeansModel.summary.clusterSizes lazy val cluster: DataFrame = kMeansModel.summary.cluster diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index 2076c745e2..2ca386e422 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR } } - test("fit & transform") { + test("fit, transform, and summary") { val predictionColName = "kmeans_prediction" val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1) val model = kmeans.fit(dataset) @@ -99,6 +99,22 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(clusters === Set(0, 1, 2, 3, 4)) assert(model.computeCost(dataset) < 0.1) assert(model.hasParent) + + // Check validity of model summary + val numRows = dataset.count() + assert(model.hasSummary) + val summary: KMeansSummary = model.summary + assert(summary.predictionCol === predictionColName) + assert(summary.featuresCol === "features") + assert(summary.predictions.count() === numRows) + for (c <- Array(predictionColName, "features")) { + assert(summary.predictions.columns.contains(c)) + } + assert(summary.cluster.columns === Array(predictionColName)) + val clusterSizes = summary.clusterSizes + assert(clusterSizes.length === k) + assert(clusterSizes.sum === numRows) + assert(clusterSizes.forall(_ >= 0)) } test("read/write") { |