author     Xusen Yin <yinxusen@gmail.com>        2016-02-23 15:42:58 -0800
committer  Xiangrui Meng <meng@databricks.com>   2016-02-23 15:42:58 -0800
commit     8d29001dec5c3695721a76df3f70da50512ef28f (patch)
tree       dcb610ddff00188cf9898cce6d3eee029c44010b
parent     15e30155631d52e35ab8522584027ab350e5acb3 (diff)
[SPARK-13011] K-means wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-13011

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11124 from yinxusen/SPARK-13011.
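As a quick orientation before the diff, here is a minimal sketch of the API this patch adds, written against the patch's own tests; the `sqlContext` and the use of `iris` are assumptions borrowed from test_mllib.R, not part of the change itself:

    # Hypothetical SparkR session; assumes an initialized sqlContext.
    df <- createDataFrame(sqlContext, iris[, -5])   # drop the Species label column

    # New S4 method: k-means on a Spark DataFrame (see R/pkg/R/mllib.R below).
    model <- kmeans(x = df, centers = 2, iter.max = 10, algorithm = "random")

    # Cluster assignments for data, as with the other SparkR wrappers.
    preds <- predict(model, df)

    # New generics exported by this patch.
    showDF(fitted(model))   # fitted values; see the "fitted" method below
    summary(model)          # centers, cluster sizes, and assignments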
-rw-r--r--  R/pkg/NAMESPACE                                                    4
-rw-r--r--  R/pkg/R/generics.R                                                 8
-rw-r--r--  R/pkg/R/mllib.R                                                   74
-rw-r--r--  R/pkg/inst/tests/testthat/test_mllib.R                            28
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala  45
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala   52
6 files changed, 203 insertions(+), 8 deletions(-)
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index f194a46303..6a3d63f43f 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -13,7 +13,9 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
"predict",
- "summary")
+ "summary",
+ "kmeans",
+ "fitted")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 2dba71abec..ab61bce03d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
#' @rdname rbind
#' @export
setGeneric("rbind", signature = "...")
+
+#' @rdname kmeans
+#' @export
+setGeneric("kmeans")
+
+#' @rdname fitted
+#' @export
+setGeneric("fitted")
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 8d3b4388ae..346f33d7da 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
setMethod("summary", signature(object = "PipelineModel"),
function(object, ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelName", object@model)
+ "getModelName", object@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelFeatures", object@model)
+ "getModelFeatures", object@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelCoefficients", object@model)
+ "getModelCoefficients", object@model)
if (modelName == "LinearRegressionModel") {
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelDevianceResiduals", object@model)
@@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
- } else {
+ } else if (modelName == "LogisticRegressionModel") {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
+ } else if (modelName == "KMeansModel") {
+ modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getKMeansModelSize", object@model)
+ cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getKMeansCluster", object@model, "classes")
+ k <- unlist(modelSize)[1]
+ size <- unlist(modelSize)[-1]
+ coefficients <- t(matrix(coefficients, ncol = k))
+ colnames(coefficients) <- unlist(features)
+ rownames(coefficients) <- 1:k
+ return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+ } else {
+ stop(paste("Unsupported model", modelName, sep = " "))
+ }
+ })
+
+#' Fit a k-means model
+#'
+#' Fit a k-means model, similar to R's kmeans().
+#'
+#' @param x DataFrame for training
+#' @param centers Number of centers
+#' @param iter.max Maximum number of iterations
+#' @param algorithm Algorithm chosen to fit the model
+#' @return A fitted k-means model
+#' @rdname kmeans
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(x, centers = 2, algorithm = "random")
+#'}
+setMethod("kmeans", signature(x = "DataFrame"),
+ function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
+ columnNames <- as.array(colnames(x))
+ algorithm <- match.arg(algorithm)
+ model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
+ algorithm, iter.max, centers, columnNames)
+ return(new("PipelineModel", model = model))
+ })
+
+#' Get fitted result from a model
+#'
+#' Get fitted result from a model, similar to R's fitted().
+#'
+#' @param object A fitted MLlib model
+#' @return DataFrame containing fitted values
+#' @rdname fitted
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(trainingData, 2)
+#' fitted.model <- fitted(model)
+#' showDF(fitted.model)
+#'}
+setMethod("fitted", signature(object = "PipelineModel"),
+ function(object, method = c("centers", "classes"), ...) {
+ modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelName", object@model)
+
+ if (modelName == "KMeansModel") {
+ method <- match.arg(method)
+ fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getKMeansCluster", object@model, method)
+ return(dataFrame(fittedResult))
+ } else {
+ stop(paste("Unsupported model", modelName, sep = " "))
}
})
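One subtlety in the kmeans() signature above: match.arg() selects the first element of the choices vector when the caller omits `algorithm`, so the R wrapper defaults to "random" even though ml.clustering.KMeans itself defaults to "k-means||". A toy illustration of that R idiom:

    pick <- function(algorithm = c("random", "k-means||")) match.arg(algorithm)
    pick()              # "random" -- the first choice wins when unspecified
    pick("k-means||")   # explicit values are matched as usual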
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 08099dd96a..595512e0e0 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
+
+test_that("kmeans", {
+ newIris <- iris
+ newIris$Species <- NULL
+ training <- suppressWarnings(createDataFrame(sqlContext, newIris))
+
+ # Cache the DataFrame here to work around the bug SPARK-13178.
+ cache(training)
+ take(training, 1)
+
+ model <- kmeans(x = training, centers = 2)
+ sample <- take(select(predict(model, training), "prediction"), 1)
+ expect_equal(typeof(sample$prediction), "integer")
+ expect_equal(sample$prediction, 1)
+
+ # Test stats::kmeans is working
+ statsModel <- kmeans(x = newIris, centers = 2)
+ expect_equal(unique(statsModel$cluster), c(1, 2))
+
+ # Test fitted works on KMeans
+ fitted.model <- fitted(model)
+ expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+ # Test summary works on KMeans
+ summary.model <- summary(model)
+ cluster <- summary.model$cluster
+ expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+})
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index b2292e20e2..c6a3eac587 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
@@ -135,6 +136,26 @@ class KMeansModel private[ml] (
@Since("1.6.0")
override def write: MLWriter = new KMeansModel.KMeansModelWriter(this)
+
+ private var trainingSummary: Option[KMeansSummary] = None
+
+ private[clustering] def setSummary(summary: KMeansSummary): this.type = {
+ this.trainingSummary = Some(summary)
+ this
+ }
+
+ /**
+ * Gets summary of model on training set. An exception is
+ * thrown if `trainingSummary == None`.
+ */
+ @Since("2.0.0")
+ def summary: KMeansSummary = trainingSummary match {
+ case Some(summ) => summ
+ case None =>
+ throw new SparkException(
+ s"No training summary available for the ${this.getClass.getSimpleName}",
+ new NullPointerException())
+ }
}
@Since("1.6.0")
@@ -249,8 +270,9 @@ class KMeans @Since("1.5.0") (
.setSeed($(seed))
.setEpsilon($(tol))
val parentModel = algo.run(rdd)
- val model = new KMeansModel(uid, parentModel)
- copyValues(model.setParent(this))
+ val model = copyValues(new KMeansModel(uid, parentModel).setParent(this))
+ val summary = new KMeansSummary(model.transform(dataset), $(predictionCol), $(featuresCol))
+ model.setSummary(summary)
}
@Since("1.5.0")
@@ -266,3 +288,22 @@ object KMeans extends DefaultParamsReadable[KMeans] {
override def load(path: String): KMeans = super.load(path)
}
+class KMeansSummary private[clustering] (
+ @Since("2.0.0") @transient val predictions: DataFrame,
+ @Since("2.0.0") val predictionCol: String,
+ @Since("2.0.0") val featuresCol: String) extends Serializable {
+
+ /**
+ * Cluster assignments produced for the input data, i.e. the prediction column.
+ */
+ @Since("2.0.0")
+ @transient lazy val cluster: DataFrame = predictions.select(predictionCol)
+
+ /**
+ * Size of each cluster.
+ */
+ @Since("2.0.0")
+ lazy val size: Array[Int] = cluster.map {
+ case Row(clusterIdx: Int) => (clusterIdx, 1)
+ }.reduceByKey(_ + _).collect().sortBy(_._1).map(_._2)
+}
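The new KMeansSummary is what backs the R-side summary() and fitted() shown earlier: `cluster` is the prediction column as a DataFrame, and `size` counts the rows assigned to each cluster id. A sketch of how those fields surface in R, assuming `model` is a fitted wrapper from the example above:

    s <- summary(model)
    s$size                      # KMeansSummary.size: rows per cluster, ordered by cluster id
    head(collect(s$cluster))    # KMeansSummary.cluster: one prediction per training row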
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index 551e75dc0a..d23e4fc9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,7 +20,8 @@ package org.apache.spark.ml.api.r
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.DataFrame
@@ -51,6 +52,22 @@ private[r] object SparkRWrappers {
pipeline.fit(df)
}
+ def fitKMeans(
+ df: DataFrame,
+ initMode: String,
+ maxIter: Double,
+ k: Double,
+ columns: Array[String]): PipelineModel = {
+ val assembler = new VectorAssembler().setInputCols(columns)
+ val kMeans = new KMeans()
+ .setInitMode(initMode)
+ .setMaxIter(maxIter.toInt)
+ .setK(k.toInt)
+ .setFeaturesCol(assembler.getOutputCol)
+ val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
+ pipeline.fit(df)
+ }
+
def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
case m: LinearRegressionModel => {
@@ -72,6 +89,8 @@ private[r] object SparkRWrappers {
m.coefficients.toArray
}
}
+ case m: KMeansModel =>
+ m.clusterCenters.flatMap(_.toArray)
}
}
@@ -85,6 +104,31 @@ private[r] object SparkRWrappers {
}
}
+ def getKMeansModelSize(model: PipelineModel): Array[Int] = {
+ model.stages.last match {
+ case m: KMeansModel => Array(m.getK) ++ m.summary.size
+ case other => throw new UnsupportedOperationException(
+ s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+ }
+ }
+
+ def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
+ model.stages.last match {
+ case m: KMeansModel =>
+ if (method == "centers") {
+ // Drop the assembled features vector so the result prints cleanly on the R side.
+ m.summary.predictions.drop(m.summary.featuresCol)
+ } else if (method == "classes") {
+ m.summary.cluster
+ } else {
+ throw new UnsupportedOperationException(
+ s"Expected method 'centers' or 'classes' but found: $method.")
+ }
+ case other => throw new UnsupportedOperationException(
+ s"KMeansModel required but ${other.getClass.getSimpleName} found.")
+ }
+ }
+
def getModelFeatures(model: PipelineModel): Array[String] = {
model.stages.last match {
case m: LinearRegressionModel =>
@@ -103,6 +147,10 @@ private[r] object SparkRWrappers {
} else {
attrs.attributes.get.map(_.name.get)
}
+ case m: KMeansModel =>
+ val attrs = AttributeGroup.fromStructField(
+ m.summary.predictions.schema(m.summary.featuresCol))
+ attrs.attributes.get.map(_.name.get)
}
}
@@ -112,6 +160,8 @@ private[r] object SparkRWrappers {
"LinearRegressionModel"
case m: LogisticRegressionModel =>
"LogisticRegressionModel"
+ case m: KMeansModel =>
+ "KMeansModel"
}
}
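Tying the wrapper together: fitted() on the R side forwards its `method` argument to getKMeansCluster() above, so both branches are reachable from R. A short sketch, with `model` again assumed to be a fitted k-means wrapper:

    fitted(model, method = "classes")   # cluster id per training row (what summary() uses)
    fitted(model, method = "centers")   # prediction output with the assembled features column dropped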
}