From 8d29001dec5c3695721a76df3f70da50512ef28f Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 23 Feb 2016 15:42:58 -0800 Subject: [SPARK-13011] K-means wrapper in SparkR https://issues.apache.org/jira/browse/SPARK-13011 Author: Xusen Yin Closes #11124 from yinxusen/SPARK-13011. --- R/pkg/NAMESPACE | 4 +- R/pkg/R/generics.R | 8 ++++ R/pkg/R/mllib.R | 74 ++++++++++++++++++++++++++++++++-- R/pkg/inst/tests/testthat/test_mllib.R | 28 +++++++++++++ 4 files changed, 109 insertions(+), 5 deletions(-) (limited to 'R/pkg') diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f194a46303..6a3d63f43f 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -13,7 +13,9 @@ export("print.jobj") # MLlib integration exportMethods("glm", "predict", - "summary") + "summary", + "kmeans", + "fitted") # Job group lifecycle management methods export("setJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 2dba71abec..ab61bce03d 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1160,3 +1160,11 @@ setGeneric("predict", function(object, ...) { standardGeneric("predict") }) #' @rdname rbind #' @export setGeneric("rbind", signature = "...") + +#' @rdname kmeans +#' @export +setGeneric("kmeans") + +#' @rdname fitted +#' @export +setGeneric("fitted") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 8d3b4388ae..346f33d7da 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -104,11 +104,11 @@ setMethod("predict", signature(object = "PipelineModel"), setMethod("summary", signature(object = "PipelineModel"), function(object, ...) { modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelName", object@model) + "getModelName", object@model) features <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelFeatures", object@model) + "getModelFeatures", object@model) coefficients <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", - "getModelCoefficients", object@model) + "getModelCoefficients", object@model) if (modelName == "LinearRegressionModel") { devianceResiduals <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "getModelDevianceResiduals", object@model) @@ -119,10 +119,76 @@ setMethod("summary", signature(object = "PipelineModel"), colnames(coefficients) <- c("Estimate", "Std. Error", "t value", "Pr(>|t|)") rownames(coefficients) <- unlist(features) return(list(devianceResiduals = devianceResiduals, coefficients = coefficients)) - } else { + } else if (modelName == "LogisticRegressionModel") { coefficients <- as.matrix(unlist(coefficients)) colnames(coefficients) <- c("Estimate") rownames(coefficients) <- unlist(features) return(list(coefficients = coefficients)) + } else if (modelName == "KMeansModel") { + modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansModelSize", object@model) + cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansCluster", object@model, "classes") + k <- unlist(modelSize)[1] + size <- unlist(modelSize)[-1] + coefficients <- t(matrix(coefficients, ncol = k)) + colnames(coefficients) <- unlist(features) + rownames(coefficients) <- 1:k + return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster))) + } else { + stop(paste("Unsupported model", modelName, sep = " ")) + } + }) + +#' Fit a k-means model +#' +#' Fit a k-means model, similarly to R's kmeans(). +#' +#' @param x DataFrame for training +#' @param centers Number of centers +#' @param iter.max Maximum iteration number +#' @param algorithm Algorithm choosen to fit the model +#' @return A fitted k-means model +#' @rdname kmeans +#' @export +#' @examples +#'\dontrun{ +#' model <- kmeans(x, centers = 2, algorithm="random") +#'} +setMethod("kmeans", signature(x = "DataFrame"), + function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) { + columnNames <- as.array(colnames(x)) + algorithm <- match.arg(algorithm) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf, + algorithm, iter.max, centers, columnNames) + return(new("PipelineModel", model = model)) + }) + +#' Get fitted result from a model +#' +#' Get fitted result from a model, similarly to R's fitted(). +#' +#' @param object A fitted MLlib model +#' @return DataFrame containing fitted values +#' @rdname fitted +#' @export +#' @examples +#'\dontrun{ +#' model <- kmeans(trainingData, 2) +#' fitted.model <- fitted(model) +#' showDF(fitted.model) +#'} +setMethod("fitted", signature(object = "PipelineModel"), + function(object, method = c("centers", "classes"), ...) { + modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getModelName", object@model) + + if (modelName == "KMeansModel") { + method <- match.arg(method) + fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "getKMeansCluster", object@model, method) + return(dataFrame(fittedResult)) + } else { + stop(paste("Unsupported model", modelName, sep = " ")) } }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 08099dd96a..595512e0e0 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -113,3 +113,31 @@ test_that("summary works on base GLM models", { baseSummary <- summary(baseModel) expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) }) + +test_that("kmeans", { + newIris <- iris + newIris$Species <- NULL + training <- suppressWarnings(createDataFrame(sqlContext, newIris)) + + # Cache the DataFrame here to work around the bug SPARK-13178. + cache(training) + take(training, 1) + + model <- kmeans(x = training, centers = 2) + sample <- take(select(predict(model, training), "prediction"), 1) + expect_equal(typeof(sample$prediction), "integer") + expect_equal(sample$prediction, 1) + + # Test stats::kmeans is working + statsModel <- kmeans(x = newIris, centers = 2) + expect_equal(unique(statsModel$cluster), c(1, 2)) + + # Test fitted works on KMeans + fitted.model <- fitted(model) + expect_equal(sort(collect(distinct(select(fitted.model, "prediction")))$prediction), c(0, 1)) + + # Test summary works on KMeans + summary.model <- summary(model) + cluster <- summary.model$cluster + expect_equal(sort(collect(distinct(select(cluster, "prediction")))$prediction), c(0, 1)) +}) -- cgit v1.2.3