aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-02-23 15:42:58 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-23 15:42:58 -0800
commit8d29001dec5c3695721a76df3f70da50512ef28f (patch)
treedcb610ddff00188cf9898cce6d3eee029c44010b /R
parent15e30155631d52e35ab8522584027ab350e5acb3 (diff)
downloadspark-8d29001dec5c3695721a76df3f70da50512ef28f.tar.gz
spark-8d29001dec5c3695721a76df3f70da50512ef28f.tar.bz2
spark-8d29001dec5c3695721a76df3f70da50512ef28f.zip
[SPARK-13011] K-means wrapper in SparkR
https://issues.apache.org/jira/browse/SPARK-13011 Author: Xusen Yin <yinxusen@gmail.com> Closes #11124 from yinxusen/SPARK-13011.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/NAMESPACE4
-rw-r--r--R/pkg/R/generics.R8
-rw-r--r--R/pkg/R/mllib.R74
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R28
4 files changed, 109 insertions, 5 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index f194a46303..6a3d63f43f 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -13,7 +13,9 @@ export("print.jobj")
# MLlib integration
exportMethods("glm",
"predict",
- "summary")
+ "summary",
+ "kmeans",
+ "fitted")
# Job group lifecycle management methods
export("setJobGroup",
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 2dba71abec..ab61bce03d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") })
#' @rdname rbind
#' @export
setGeneric("rbind", signature = "...")
+
+#' @rdname kmeans
+#' @export
+setGeneric("kmeans")
+
+#' @rdname fitted
+#' @export
+setGeneric("fitted")
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 8d3b4388ae..346f33d7da 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"),
setMethod("summary", signature(object = "PipelineModel"),
function(object, ...) {
modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelName", object@model)
+ "getModelName", object@model)
features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelFeatures", object@model)
+ "getModelFeatures", object@model)
coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelCoefficients", object@model)
+ "getModelCoefficients", object@model)
if (modelName == "LinearRegressionModel") {
devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
"getModelDevianceResiduals", object@model)
@@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)")
rownames(coefficients) <- unlist(features)
return(list(devianceResiduals = devianceResiduals, coefficients = coefficients))
- } else {
+ } else if (modelName == "LogisticRegressionModel") {
coefficients <- as.matrix(unlist(coefficients))
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
+ } else if (modelName == "KMeansModel") {
+ modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getKMeansModelSize", object@model)
+ cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getKMeansCluster", object@model, "classes")
+ k <- unlist(modelSize)[1]
+ size <- unlist(modelSize)[-1]
+ coefficients <- t(matrix(coefficients, ncol = k))
+ colnames(coefficients) <- unlist(features)
+ rownames(coefficients) <- 1:k
+ return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+ } else {
+ stop(paste("Unsupported model", modelName, sep = " "))
+ }
+ })
+
+#' Fit a k-means model
+#'
+#' Fit a k-means model, similarly to R's kmeans().
+#'
+#' @param x DataFrame for training
+#' @param centers Number of centers
+#' @param iter.max Maximum iteration number
+#' @param algorithm Algorithm choosen to fit the model
+#' @return A fitted k-means model
+#' @rdname kmeans
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(x, centers = 2, algorithm="random")
+#'}
+setMethod("kmeans", signature(x = "DataFrame"),
+ function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
+ columnNames <- as.array(colnames(x))
+ algorithm <- match.arg(algorithm)
+ model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
+ algorithm, iter.max, centers, columnNames)
+ return(new("PipelineModel", model = model))
+ })
+
+#' Get fitted result from a model
+#'
+#' Get fitted result from a model, similarly to R's fitted().
+#'
+#' @param object A fitted MLlib model
+#' @return DataFrame containing fitted values
+#' @rdname fitted
+#' @export
+#' @examples
+#'\dontrun{
+#' model <- kmeans(trainingData, 2)
+#' fitted.model <- fitted(model)
+#' showDF(fitted.model)
+#'}
+setMethod("fitted", signature(object = "PipelineModel"),
+ function(object, method = c("centers", "classes"), ...) {
+ modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getModelName", object@model)
+
+ if (modelName == "KMeansModel") {
+ method <- match.arg(method)
+ fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
+ "getKMeansCluster", object@model, method)
+ return(dataFrame(fittedResult))
+ } else {
+ stop(paste("Unsupported model", modelName, sep = " "))
}
})
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 08099dd96a..595512e0e0 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -113,3 +113,31 @@ test_that("summary works on base GLM models", {
baseSummary <- summary(baseModel)
expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4)
})
+
+test_that("kmeans", {
+ newIris <- iris
+ newIris$Species <- NULL
+ training <- suppressWarnings(createDataFrame(sqlContext, newIris))
+
+ # Cache the DataFrame here to work around the bug SPARK-13178.
+ cache(training)
+ take(training, 1)
+
+ model <- kmeans(x = training, centers = 2)
+ sample <- take(select(predict(model, training), "prediction"), 1)
+ expect_equal(typeof(sample$prediction), "integer")
+ expect_equal(sample$prediction, 1)
+
+ # Test stats::kmeans is working
+ statsModel <- kmeans(x = newIris, centers = 2)
+ expect_equal(unique(statsModel$cluster), c(1, 2))
+
+ # Test fitted works on KMeans
+ fitted.model <- fitted(model)
+ expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1))
+
+ # Test summary works on KMeans
+ summary.model <- summary(model)
+ cluster <- summary.model$cluster
+ expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1))
+})