From 8d29001dec5c3695721a76df3f70da50512ef28f Mon Sep 17 00:00:00 2001
From: Xusen Yin
Date: Tue, 23 Feb 2016 15:42:58 -0800
Subject: [SPARK-13011] K-means wrapper in SparkR

https://issues.apache.org/jira/browse/SPARK-13011

Author: Xusen Yin

Closes #11124 from yinxusen/SPARK-13011.
---
 R/pkg/NAMESPACE                                    |  4 +-
 R/pkg/R/generics.R                                 |  8 +++
 R/pkg/R/mllib.R                                    | 74 ++++++++++++++++++++--
 R/pkg/inst/tests/testthat/test_mllib.R             | 28 ++++++++
 .../org/apache/spark/ml/clustering/KMeans.scala    | 45 ++++++++++++-
 .../org/apache/spark/ml/r/SparkRWrappers.scala     | 52 ++++++++++++++-
 6 files changed, 203 insertions(+), 8 deletions(-)

diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index f194a46303..6a3d63f43f 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -13,7 +13,9 @@ export("print.jobj")
 # MLlib integration
 exportMethods("glm",
               "predict",
-              "summary")
+              "summary",
+              "kmeans",
+              "fitted")
 
 # Job group lifecycle management methods
 export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 2dba71abec..ab61bce03d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
 #' @rdname rbind
 #' @export
 setGeneric("rbind", signature = "...")
+
+#' @rdname kmeans
+#' @export
+setGeneric("kmeans")
+
+#' @rdname fitted
+#' @export
+setGeneric("fitted")
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 8d3b4388ae..346f33d7da 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
 setMethod("summary", signature(object = "PipelineModel"),
           function(object, ...) {
             modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                              "getModelName", object@model)
+                                     "getModelName", object@model)
             features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                              "getModelFeatures", object@model)
+                                    "getModelFeatures", object@model)
             coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
-                              "getModelCoefficients", object@model)
+                                        "getModelCoefficients", object@model)
             if (modelName == "LinearRegressionModel") {
               devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
                                                "getModelDevianceResiduals", object@model)
@@ -119,10 +119,76 @@
               colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
               rownames(coefficients) <- unlist(features)
               return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
-            } else {
+            } else if (modelName == "LogisticRegressionModel") {
               coefficients <- as.matrix(unlist(coefficients))
               colnames(coefficients) <- c("Estimate")
               rownames(coefficients) <- unlist(features)
               return(list(coefficients = coefficients))
+            } else if (modelName == "KMeansModel") {
+              modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                       "getKMeansModelSize", object@model)
+              cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getKMeansCluster", object@model, "classes")
+              k <- unlist(modelSize)[1]
+              size <- unlist(modelSize)[-1]
+              coefficients <- t(matrix(coefficients, ncol = k))
+              colnames(coefficients) <- unlist(features)
+              rownames(coefficients) <- 1:k
+              return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+            } else {
+              stop(paste("Unsupported model", modelName, sep = " "))
             }
           })
+
+#' Fit a k-means model
+#'
+#' Fit a k-means model, similarly to R's kmeans().
+#'
+#' @param x DataFrame for training
+#' @param centers Number of centers
+#' @param iter.max Maximum iteration number
+#' @param algorithm Algorithm chosen to fit the model
+#' @return A fitted k-means model
+#' @rdname kmeans
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(x, centers = 2, algorithm="random")
+#'}
+setMethod("kmeans", signature(x = "DataFrame"),
+          function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
+            columnNames <- as.array(colnames(x))
+            algorithm <- match.arg(algorithm)
+            model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
+                                 algorithm, iter.max, centers, columnNames)
+            return(new("PipelineModel", model = model))
+          })
+
+#' Get fitted result from a model
+#'
+#' Get fitted result from a model, similarly to R's fitted().
+#'
+#' @param object A fitted MLlib model
+#' @return DataFrame containing fitted values
+#' @rdname fitted
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(trainingData, 2)
+#' fitted.model <- fitted(model)
+#' showDF(fitted.model)
+#'}
+setMethod("fitted", signature(object = "PipelineModel"),
+          function(object, method = c("centers", "classes"), ...) {
+            modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                     "getModelName", object@model)
+
+            if (modelName == "KMeansModel") {
+              method <- match.arg(method)
+              fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+                                          "getKMeansCluster", object@model, method)
+              return(dataFrame(fittedResult))
+            } else {
+              stop(paste("Unsupported model", modelName, sep = " "))
+            }
+          })
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 08099dd96a..595512e0e0 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
   baseSummary <- summary(baseModel)
   expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
 })
+
+test_that("kmeans", {
+  newIris <- iris
+  newIris$Species <- NULL
+  training <- suppressWarnings(createDataFrame(sqlContext, newIris))
+
+  # Cache the DataFrame here to work around the bug SPARK-13178.
+  cache(training)
+  take(training, 1)
+
+  model <- kmeans(x = training, centers = 2)
+  sample <- take(select(predict(model, training), "prediction"), 1)
+  expect_equal(typeof(sample$prediction), "integer")
+  expect_equal(sample$prediction, 1)
+
+  # Test stats::kmeans is working
+  statsModel <- kmeans(x = newIris, centers = 2)
+  expect_equal(unique(statsModel$cluster), c(1, 2))
+
+  # Test fitted works on KMeans
+  fitted.model <- fitted(model)
+  expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+  # Test summary works on KMeans
+  summary.model <- summary(model)
+  cluster <- summary.model$cluster
+  expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index b2292e20e2..c6a3eac587 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
 
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.SparkException
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
@@ -135,6 +136,26 @@ class KMeansModel private[ml] (
 
   @Since("1.6.0")
   override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+
+  private var trainingSummary: Option[KMeansSummary] = None
+
+  private[clustering] def setSummary(summary: KMeansSummary): this.type = {
+    this.trainingSummary = Some(summary)
+    this
+  }
+
+  /**
+   * Gets summary of model on training set. An exception is
+   * thrown if `trainingSummary == None`.
+   */
+  @Since("2.0.0")
+  def summary: KMeansSummary = trainingSummary match {
+    case Some(summ) => summ
+    case None =>
+      throw new SparkException(
+        s"No training summary available for the ${this.getClass.getSimpleName}",
+        new NullPointerException())
+  }
 }
 
 @Since("1.6.0")
@@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") (
       .setSeed($(seed))
       .setEpsilon($(tol))
     val parentModel = algo.run(rdd)
-    val model = new KMeansModel(uid, parentModel)
-    copyValues(model.setParent(this))
+    val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
+    val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+    model.setSummary(summary)
   }
 
   @Since("1.5.0")
@@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
   override def load(path: String): KMeans = super.load(path)
 }
 
+class KMeansSummary private[clustering] (
+    @Since("2.0.0") @transient val predictions: DataFrame,
+    @Since("2.0.0") val predictionCol: String,
+    @Since("2.0.0") val featuresCol: String) extends Serializable {
+
+  /**
+   * Cluster centers of the transformed data.
+   */
+  @Since("2.0.0")
+  @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+  /**
+   * Size of each cluster.
+   */
+  @Since("2.0.0")
+  lazy val size: Array[Int] = cluster.map {
+    case Row(clusterIdx: Int) => (clusterIdx, 1)
+  }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 551e75dc0a..d23e4fc9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r
 import org.apache.spark.ml.{Pipeline, PipelineModel}
 import org.apache.spark.ml.attribute._
 import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
 import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
 import org.apache.spark.sql.DataFrame
 
@@ -51,6 +52,22 @@ private[r] object SparkRWrappers {
     pipeline.fit(df)
   }
 
+  def fitKMeans(
+      df: DataFrame,
+      initMode: String,
+      maxIter: Double,
+      k: Double,
+      columns: Array[String]): PipelineModel = {
+    val assembler = new VectorAssembler().setInputCols(columns)
+    val kMeans = new KMeans()
+      .setInitMode(initMode)
+      .setMaxIter(maxIter.toInt)
+      .setK(k.toInt)
+      .setFeaturesCol(assembler.getOutputCol)
+    val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
+    pipeline.fit(df)
+  }
+
   def getModelCoefficients(model: PipelineModel): Array[Double] = {
     model.stages.last match {
       case m: LinearRegressionModel => {
@@ -72,6 +89,8 @@ private[r] object SparkRWrappers {
           m.coefficients.toArray
         }
       }
+      case m: KMeansModel =>
+        m.clusterCenters.flatMap(_.toArray)
     }
   }
 
@@ -85,6 +104,31 @@ private[r] object SparkRWrappers {
     }
   }
 
+  def getKMeansModelSize(model: PipelineModel): Array[Int] = {
+    model.stages.last match {
+      case m: KMeansModel => Array(m.getK) ++ m.summary.size
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
+  def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
+    model.stages.last match {
+      case m: KMeansModel =>
+        if (method == "centers") {
+          // Drop the assembled vector for easy-print to R side.
+          m.summary.predictions.drop(m.summary.featuresCol)
+        } else if (method == "classes") {
+          m.summary.cluster
+        } else {
+          throw new UnsupportedOperationException(
+            s"Method (centers or classes) required but $method found.")
+        }
+      case other => throw new UnsupportedOperationException(
+        s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+    }
+  }
+
   def getModelFeatures(model: PipelineModel): Array[String] = {
     model.stages.last match {
       case m: LinearRegressionModel =>
@@ -103,6 +147,10 @@ private[r] object SparkRWrappers {
         } else {
           attrs.attributes.get.map(_.name.get)
         }
+      case m: KMeansModel =>
+        val attrs = AttributeGroup.fromStructField(
+          m.summary.predictions.schema(m.summary.featuresCol))
+        attrs.attributes.get.map(_.name.get)
     }
   }
 
@@ -112,6 +160,8 @@ private[r] object SparkRWrappers {
         "LinearRegressionModel"
       case m: LogisticRegressionModel =>
         "LogisticRegressionModel"
+      case m: KMeansModel =>
+        "KMeansModel"
     }
   }
 }
-- 
cgit v1.2.3
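For anyone trying out the new API from a SparkR shell, here is a minimal usage sketch (not part of the commit) modeled on the test case and roxygen examples in the patch above; it assumes `sqlContext` has already been created, as in the test.

# Build a Spark DataFrame from iris, dropping the factor column first.
newIris <- iris
newIris$Species <- NULL
training <- suppressWarnings(createDataFrame(sqlContext, newIris))

# Fit a 2-cluster model; algorithm may be "random" or "k-means||".
model <- kmeans(x = training, centers = 2, iter.max = 10, algorithm = "random")

# Cluster assignments for arbitrary data: a DataFrame with a "prediction" column.
showDF(select(predict(model, training), "prediction"))

# Assignments for the training data itself, analogous to R's fitted().
showDF(fitted(model))

# summary() returns the centers ("coefficients"), the per-cluster sizes,
# and the training-data assignments as a DataFrame.
summary(model)$coefficients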