author     Xusen Yin <yinxusen@gmail.com>        2016-02-23 15:42:58 -0800
committer  Xiangrui Meng <meng@databricks.com>   2016-02-23 15:42:58 -0800
commit     8d29001dec5c3695721a76df3f70da50512ef28f (patch)
tree       dcb610ddff00188cf9898cce6d3eee029c44010b /mllib
parent     15e30155631d52e35ab8522584027ab350e5acb3 (diff)
[SPARK-13011] K-means wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-13011

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11124 from yinxusen/SPARK-13011.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala  | 45
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala   | 52
2 files changed, 94 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index b2292e20e2..c6a3eac587 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
 import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
@@ -135,6 +136,26 @@ class KMeansModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+
+ private var trainingSummary: Option[KMeansSummary] = None
+
+ private[clustering] def setSummary(summary: KMeansSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.0.0")
+ def summary: KMeansSummary = trainingSummary match {
+ case Some(summ) => summ
+ case None =>
+ throw new SparkException(
+ s"No training summary available for the ${this.getClass.getSimpleName}",
+ new NullPointerException())
+ }
}
@Since("1.6.0")
@@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") (
       .setSeed($(seed))
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
-    val model = new KMeansModel(uid, parentModel)
-    copyValues(model.setParent(this))
+    val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
+    val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+    model.setSummary(summary)
   }

   @Since("1.5.0")
@@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
   override def load(path: String): KMeans = super.load(path)
 }
+
+class KMeansSummary private[clustering] (
+    @Since("2.0.0") @transient val predictions: DataFrame,
+    @Since("2.0.0") val predictionCol: String,
+    @Since("2.0.0") val featuresCol: String) extends Serializable {
+
+  /**
+   * Cluster assignment of each point in the transformed data.
+   */
+  @Since("2.0.0")
+  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+  /**
+   * Size of each cluster.
+   */
+  @Since("2.0.0")
+  lazy val size: Array[Int] = cluster.map {
+    case Row(clusterIdx: Int) => (clusterIdx, 1)
+  }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+}
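
For orientation, a minimal sketch of how the new summary API is used from Scala once this patch is applied; the DataFrame `df` with a vector column "features" is an assumption for illustration, not part of the patch:

import org.apache.spark.ml.clustering.{KMeans, KMeansModel}

// Assumed: a DataFrame `df` with a Vector column named "features".
val model: KMeansModel = new KMeans()
  .setK(2)
  .setFeaturesCol("features")
  .setPredictionCol("prediction")
  .fit(df)

// fit() attaches a KMeansSummary via setSummary; calling `summary` on a
// model that has none (e.g. one loaded from disk) throws SparkException.
val summary = model.summary
summary.cluster.show()                // predicted cluster index per row
println(summary.size.mkString(", ")) // number of points in each cluster

Note that `size` is a lazy val, so its first access runs a Spark job (the map/reduceByKey over `cluster` shown in the patch).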
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 551e75dc0a..d23e4fc9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame
@@ -51,6 +52,22 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }

+  def fitKMeans(
+      df: DataFrame,
+      initMode: String,
+      maxIter: Double,
+      k: Double,
+      columns: Array[String]): PipelineModel = {
+    val assembler = new VectorAssembler().setInputCols(columns)
+    val kMeans = new KMeans()
+      .setInitMode(initMode)
+      .setMaxIter(maxIter.toInt)
+      .setK(k.toInt)
+      .setFeaturesCol(assembler.getOutputCol)
+    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
+    pipeline.fit(df)
+  }
+
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -72,6 +89,8 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
+      case m: KMeansModel =>
+        m.clusterCenters.flatMap(_.toArray)
     }
   }
@@ -85,6 +104,31 @@ private[r] object SparkRWrappers {
     }
   }

+  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
+    model.stages.last match {
+      case m: KMeansModel => Array(m.getK) ++ m.summary.size
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
+  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
+    model.stages.last match {
+      case m: KMeansModel =>
+        if (method == "centers") {
+          // Drop the assembled vector column so the result prints cleanly on the R side.
+          m.summary.predictions.drop(m.summary.featuresCol)
+        } else if (method == "classes") {
+          m.summary.cluster
+        } else {
+          throw new UnsupportedOperationException(
+            s"Expected method 'centers' or 'classes', but found '$method'.")
+        }
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -103,6 +147,10 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
+      case m: KMeansModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        attrs.attributes.get.map(_.name.get)
     }
   }
@@ -112,6 +160,8 @@ private[r] object SparkRWrappers {
"LinearRegressionModel"
case m: LogisticRegressionModel =>
"LogisticRegressionModel"
+ case m: KMeansModel =>
+ "KMeansModel"
}
}
}
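
To see how these entry points compose, a hedged JVM-side sketch; SparkR reaches this private[r] object via callJStatic, and the DataFrame `df` with numeric columns "x" and "y" is an assumption for illustration:

// Assumed: a DataFrame `df` with numeric columns "x" and "y".
// fitKMeans assembles the columns into a feature vector, then fits k-means.
val model = SparkRWrappers.fitKMeans(df, initMode = "k-means||",
  maxIter = 20.0, k = 2.0, columns = Array("x", "y"))

SparkRWrappers.getModelName(model)                // "KMeansModel"
SparkRWrappers.getModelCoefficients(model)        // flattened cluster centers
SparkRWrappers.getKMeansModelSize(model)          // Array(k, size of cluster 0, ...)
SparkRWrappers.getKMeansCluster(model, "classes") // DataFrame of cluster assignments
SparkRWrappers.getKMeansCluster(model, "centers") // predictions minus the vector column

maxIter and k are typed as Double because R passes numeric scalars as doubles; the wrapper converts them with .toInt before configuring the estimator.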