author    Zheng RuiFeng <ruifengz@foxmail.com>    2016-10-14 04:25:14 -0700
committer Yanbo Liang <ybliang8@gmail.com>        2016-10-14 04:25:14 -0700
commit    a1b136d05c6c458ae8211b0844bfc98d7693fa42 (patch)
tree      b9fef5799c45c13fd3979a7e1d0be9853377088f /mllib/src/main
parent    1db8feab8c564053c05e8bdc1a7f5026fd637d4f (diff)
[SPARK-14634][ML] Add BisectingKMeansSummary
## What changes were proposed in this pull request?

Add BisectingKMeansSummary

## How was this patch tested?

unit test

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12394 from zhengruifeng/biKMSummary.
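For context, a minimal usage sketch of the API this patch adds (the `dataset` value is hypothetical: any DataFrame with a "features" vector column):

```scala
import org.apache.spark.ml.clustering.BisectingKMeans

// Fitting attaches a BisectingKMeansSummary to the returned model.
val model = new BisectingKMeans()
  .setK(4)
  .setSeed(1L)
  .fit(dataset)  // `dataset`: hypothetical DataFrame with a "features" vector column

// The summary only exists on a freshly trained model, so guard with hasSummary;
// calling summary when none is set throws a SparkException.
if (model.hasSummary) {
  val summary = model.summary
  println(s"k = ${summary.k}")
  println(s"cluster sizes = ${summary.clusterSizes.mkString(", ")}")
  summary.cluster.show()  // one row per input point: its predicted cluster index
}
```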
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala | 74
1 file changed, 72 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index a97bd0fb16..add8ee2a4f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
@@ -127,6 +128,29 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def write: MLWriter = new BisectingKMeansModel.BisectingKMeansModelWriter(this)
+
+ private var trainingSummary: Option[BisectingKMeansSummary] = None
+
+ private[clustering] def setSummary(summary: BisectingKMeansSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+ * Return true if a summary of the model exists.
+ */
+ @Since("2.1.0")
+ def hasSummary: Boolean = trainingSummary.nonEmpty
+
+ /**
+ * Gets the summary of the model on the training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.1.0")
+ def summary: BisectingKMeansSummary = trainingSummary.getOrElse {
+ throw new SparkException(
+ s"No training summary available for the ${this.getClass.getSimpleName}")
+ }
}
object BisectingKMeansModel extends MLReadable[BisectingKMeansModel] {
@@ -228,14 +252,22 @@ class BisectingKMeans @Since("2.0.0") (
case Row(point: Vector) => OldVectors.fromML(point)
}
+ val instr = Instrumentation.create(this, rdd)
+ instr.logParams(featuresCol, predictionCol, k, maxIter, seed, minDivisibleClusterSize)
+
val bkm = new MLlibBisectingKMeans()
.setK($(k))
.setMaxIterations($(maxIter))
.setMinDivisibleClusterSize($(minDivisibleClusterSize))
.setSeed($(seed))
val parentModel = bkm.run(rdd)
- val model = new BisectingKMeansModel(uid, parentModel)
- copyValues(model.setParent(this))
+ val model = copyValues(new BisectingKMeansModel(uid, parentModel).setParent(this))
+ val summary = new BisectingKMeansSummary(
+ model.transform(dataset), $(predictionCol), $(featuresCol), $(k))
+ val m = model.setSummary(summary)
+ instr.logSuccess(m)
+ m
}
@Since("2.0.0")
@@ -251,3 +283,41 @@ object BisectingKMeans extends DefaultParamsReadable[BisectingKMeans] {
@Since("2.0.0")
override def load(path: String): BisectingKMeans = super.load(path)
}
+
+
+/**
+ * :: Experimental ::
+ * Summary of BisectingKMeans.
+ *
+ * @param predictions [[DataFrame]] produced by [[BisectingKMeansModel.transform()]]
+ * @param predictionCol Name for column of predicted clusters in `predictions`
+ * @param featuresCol Name for column of features in `predictions`
+ * @param k Number of clusters
+ */
+@Since("2.1.0")
+@Experimental
+class BisectingKMeansSummary private[clustering] (
+ @Since("2.1.0") @transient val predictions: DataFrame,
+ @Since("2.1.0") val predictionCol: String,
+ @Since("2.1.0") val featuresCol: String,
+ @Since("2.1.0") val k: Int) extends Serializable {
+
+ /**
+ * Cluster assignments of the transformed data: the predicted cluster index for each data point.
+ */
+ @Since("2.1.0")
+ @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+ /**
+ * Size of (number of data points in) each cluster.
+ */
+ @Since("2.1.0")
+ lazy val clusterSizes: Array[Long] = {
+ val sizes = Array.fill[Long](k)(0)
+ cluster.groupBy(predictionCol).count().select(predictionCol, "count").collect().foreach {
+ case Row(cluster: Int, count: Long) => sizes(cluster) = count
+ }
+ sizes
+ }
+
+}
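For reference, `clusterSizes` is a fixed-length array indexed by cluster, so clusters that receive no points keep a count of zero. A sketch of the equivalent DataFrame-side computation (assuming `summary` is the `BisectingKMeansSummary` of a trained model):

```scala
// Count points per predicted cluster, then spread the counts into a
// k-length array so empty clusters show up as 0 rather than being absent.
val counts: Map[Int, Long] = summary.predictions
  .groupBy(summary.predictionCol)
  .count()
  .collect()
  .map(row => row.getInt(0) -> row.getLong(1))
  .toMap
val sizes: Array[Long] = Array.tabulate(summary.k)(i => counts.getOrElse(i, 0L))
```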