diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-02-23 15:42:58 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-23 15:42:58 -0800 |
commit | 8d29001dec5c3695721a76df3f70da50512ef28f (patch) | |
tree | dcb610ddff00188cf9898cce6d3eee029c44010b /mllib/src | |
parent | 15e30155631d52e35ab8522584027ab350e5acb3 (diff) | |
download | spark-8d29001dec5c3695721a76df3f70da50512ef28f.tar.gz spark-8d29001dec5c3695721a76df3f70da50512ef28f.tar.bz2 spark-8d29001dec5c3695721a76df3f70da50512ef28f.zip |
[SPARK-13011] K-means wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-13011
Author: Xusen Yin <yinxusen@gmail.com>
Closes #11124 from yinxusen/SPARK-13011.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 45 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala | 52 |
2 files changed, 94 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index b2292e20e2..c6a3eac587 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} @@ -135,6 +136,26 @@ class KMeansModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new KMeansModel.KMeansModelWriter(this) + + private var trainingSummary: Option[KMeansSummary] = None + + private[clustering] def setSummary(summary: KMeansSummary): this.type = { + this.trainingSummary = Some(summary) + this + } + + /** + * Gets summary of model on training set. An exception is + * thrown if `trainingSummary == None`. + */ + @Since("2.0.0") + def summary: KMeansSummary = trainingSummary match { + case Some(summ) => summ + case None => + throw new SparkException( + s"No training summary available for the ${this.getClass.getSimpleName}", + new NullPointerException()) + } } @Since("1.6.0") @@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") ( .setSeed($(seed)) .setEpsilon($(tol)) val parentModel = algo.run(rdd) - val model = new KMeansModel(uid, parentModel) - copyValues(model.setParent(this)) + val model = copyValues(new KMeansModel(uid, parentModel).setParent(this)) + val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol)) + model.setSummary(summary) } @Since("1.5.0") @@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] { override def load(path: String): KMeans = super.load(path) } +class KMeansSummary private[clustering] ( + @Since("2.0.0") @transient val predictions: DataFrame, + @Since("2.0.0") val predictionCol: String, + @Since("2.0.0") val featuresCol: String) extends Serializable { + + /** + * Cluster centers of the transformed data. + */ + @Since("2.0.0") + @transient lazy val cluster: DataFrame = predictions.select(predictionCol) + + /** + * Size of each cluster. + */ + @Since("2.0.0") + lazy val size: Array[Int] = cluster.map { + case Row(clusterIdx: Int) => (clusterIdx, 1) + }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2) +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 551e75dc0a..d23e4fc9d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} -import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.clustering.{KMeans, KMeansModel} +import org.apache.spark.ml.feature.{RFormula, VectorAssembler} import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.sql.DataFrame @@ -51,6 +52,22 @@ private[r] object SparkRWrappers { pipeline.fit(df) } + def fitKMeans( + df: DataFrame, + initMode: String, + maxIter: Double, + k: Double, + columns: Array[String]): PipelineModel = { + val assembler = new VectorAssembler().setInputCols(columns) + val kMeans = new KMeans() + .setInitMode(initMode) + .setMaxIter(maxIter.toInt) + .setK(k.toInt) + .setFeaturesCol(assembler.getOutputCol) + val pipeline = new Pipeline().setStages(Array(assembler, kMeans)) + pipeline.fit(df) + } + def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { case m: LinearRegressionModel => { @@ -72,6 +89,8 @@ private[r] object SparkRWrappers { m.coefficients.toArray } } + case m: KMeansModel => + m.clusterCenters.flatMap(_.toArray) } } @@ -85,6 +104,31 @@ private[r] object SparkRWrappers { } } + def getKMeansModelSize(model: PipelineModel): Array[Int] = { + model.stages.last match { + case m: KMeansModel => Array(m.getK) ++ m.summary.size + case other => throw new UnsupportedOperationException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") + } + } + + def getKMeansCluster(model: PipelineModel, method: String): DataFrame = { + model.stages.last match { + case m: KMeansModel => + if (method == "centers") { + // Drop the assembled vector for easy-print to R side. + m.summary.predictions.drop(m.summary.featuresCol) + } else if (method == "classes") { + m.summary.cluster + } else { + throw new UnsupportedOperationException( + s"Method (centers or classes) required but $method found.") + } + case other => throw new UnsupportedOperationException( + s"KMeansModel required but ${other.getClass.getSimpleName} found.") + } + } + def getModelFeatures(model: PipelineModel): Array[String] = { model.stages.last match { case m: LinearRegressionModel => @@ -103,6 +147,10 @@ private[r] object SparkRWrappers { } else { attrs.attributes.get.map(_.name.get) } + case m: KMeansModel => + val attrs = AttributeGroup.fromStructField( + m.summary.predictions.schema(m.summary.featuresCol)) + attrs.attributes.get.map(_.name.get) } } @@ -112,6 +160,8 @@ private[r] object SparkRWrappers { "LinearRegressionModel" case m: LogisticRegressionModel => "LogisticRegressionModel" + case m: KMeansModel => + "KMeansModel" } } } |