aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-10-14 04:25:14 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-10-14 04:25:14 -0700
commita1b136d05c6c458ae8211b0844bfc98d7693fa42 (patch)
treeb9fef5799c45c13fd3979a7e1d0be9853377088f /mllib
parent1db8feab8c564053c05e8bdc1a7f5026fd637d4f (diff)
downloadspark-a1b136d05c6c458ae8211b0844bfc98d7693fa42.tar.gz
spark-a1b136d05c6c458ae8211b0844bfc98d7693fa42.tar.bz2
spark-a1b136d05c6c458ae8211b0844bfc98d7693fa42.zip
[SPARK-14634][ML] Add BisectingKMeansSummary
## What changes were proposed in this pull request?
Add BisectingKMeansSummary

## How was this patch tested?
unit test

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12394 from zhengruifeng/biKMSummary.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala74
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala2
4 files changed, 91 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index a97bd0fb16..add8ee2a4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+
+ private var trainingSummary: Option[BisectingKMeansSummary] = None
+
+ private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+ * Returns true if a summary of the model exists.
+ */
+ @Since("2.1.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.1.0")
+ def summary: BisectingKMeansSummary = trainingSummary.getOrElse {
+ throw new SparkException(
+ s"No training summary available for the ${this.getClass.getSimpleName}")
+ }
}
object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
@@ -228,14 +252,22 @@ class BisectingKMeans @Since("2.0.0") (
case Row(point: Vector) => OldVectors.fromML(point)
}
+ val instr = Instrumentation.create(this, rdd)
+ instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
+
val bkm = new MLlibBisectingKMeans()
.setK($(k))
.setMaxIterations($(maxIter))
.setMinDivisibleClusterSize($(minDivisibleClusterSize))
.setSeed($(seed))
val parentModel = bkm.run(rdd)
- val model = new BisectingKMeansModel(uid, parentModel)
- copyValues(model.setParent(this))
+ val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
+ val summary = new BisectingKMeansSummary(
+ model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
+ model.setSummary(summary)
+ val m = model.setSummary(summary)
+ instr.logSuccess(m)
+ m
}
@Since("2.0.0")
@@ -251,3 +283,41 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
@Since("2.0.0")
override def load(path: String): BisectingKMeans = super.load(path)
}
+
+
+/**
+ * :: Experimental ::
+ * Summary of BisectingKMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.1.0")
+@Experimental
+class BisectingKMeansSummary private[clustering] (
+ @Since("2.1.0") @transient val predictions: DataFrame,
+ @Since("2.1.0") val predictionCol: String,
+ @Since("2.1.0") val featuresCol: String,
+ @Since("2.1.0") val k: Int) extends Serializable {
+
+ /**
+ * Cluster assignments (the prediction column) of the transformed data.
+ */
+ @Since("2.1.0")
+ @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+ /**
+ * Size of (number of data points in) each cluster.
+ */
+ @Since("2.1.0")
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
+
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
index 4f7d4418a8..f2368a9f8d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -68,7 +68,7 @@ class BisectingKMeansSuite
}
}
- test("fit & transform") {
+ test("fit, transform and summary") {
val predictionColName = "bisecting_kmeans_prediction"
val bkm = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = bkm.fit(dataset)
@@ -85,6 +85,22 @@ class BisectingKMeansSuite
assert(clusters === Set(0, 1, 2, 3, 4))
assert(model.computeCost(dataset) < 0.1)
assert(model.hasParent)
+
+ // Check validity of model summary
+ val numRows = dataset.count()
+ assert(model.hasSummary)
+ val summary: BisectingKMeansSummary = model.summary
+ assert(summary.predictionCol === predictionColName)
+ assert(summary.featuresCol === "features")
+ assert(summary.predictions.count() === numRows)
+ for (c <- Array(predictionColName, "features")) {
+ assert(summary.predictions.columns.contains(c))
+ }
+ assert(summary.cluster.columns === Array(predictionColName))
+ val clusterSizes = summary.clusterSizes
+ assert(clusterSizes.length === k)
+ assert(clusterSizes.sum === numRows)
+ assert(clusterSizes.forall(_ >= 0))
}
test("read/write") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
index 04366f5250..003fa6abf6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala
@@ -70,7 +70,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
- test("fit, transform, and summary") {
+ test("fit, transform and summary") {
val predictionColName = "gm_prediction"
val probabilityColName = "gm_probability"
val gm = new GaussianMixture().setK(k).setMaxIter(2).setPredictionCol(predictionColName)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index c9ba5a288a..ca39265355 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -82,7 +82,7 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
}
}
- test("fit, transform, and summary") {
+ test("fit, transform and summary") {
val predictionColName = "kmeans_prediction"
val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
val model = kmeans.fit(dataset)