author    Yanbo Liang <ybliang8@gmail.com>      2016-03-31 23:49:58 -0700
committer Xiangrui Meng <meng@databricks.com>   2016-03-31 23:49:58 -0700
commit    22249afb4a932a82ff1f7a3befea9fda5a60a3f4 (patch)
tree      107b6166b9f3e1ec51c5d8681c10af7ec57bc836 /R/pkg
parent    26867ebc67edab97376c5d8fee76df294359e461 (diff)
[SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for SparkR::kmeans
## What changes were proposed in this pull request?

Define and use ```KMeansWrapper``` for ```SparkR::kmeans```. This is only a code refactoring of the original ```KMeans``` wrapper.

## How was this patch tested?

Existing tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12039 from yanboliang/spark-14059.
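For illustration, a minimal SparkR usage sketch of the k-means API touched by this refactor, based on the roxygen examples in the diff below; `trainingData` and `testData` are assumed to be SparkR DataFrames created elsewhere.

```r
# Assumed: trainingData and testData are SparkR DataFrames with numeric feature columns.
model <- kmeans(trainingData, centers = 2, algorithm = "random")

summary(model)                     # list with coefficients, size, and cluster
fitted.model <- fitted(model, method = "centers")
showDF(fitted.model)

predicted <- predict(model, testData)
showDF(predicted)                  # predicted labels in a column named "prediction"
```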
Diffstat (limited to 'R/pkg')
-rw-r--r--  R/pkg/R/mllib.R   91
1 file changed, 62 insertions(+), 29 deletions(-)
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 33654d5216..f3152cc232 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -32,6 +32,11 @@ setClass("NaiveBayesModel", representation(jobj = "jobj"))
#' @export
setClass("AFTSurvivalRegressionModel", representation(jobj = "jobj"))
+#' @title S4 class that represents a KMeansModel
+#' @param jobj a Java object reference to the backing Scala KMeansModel
+#' @export
+setClass("KMeansModel", representation(jobj = "jobj"))
+
#' Fits a generalized linear model
#'
#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package.
@@ -154,17 +159,6 @@ setMethod("summary", signature(object = "PipelineModel"),
colnames(coefficients) <- c("Estimate")
rownames(coefficients) <- unlist(features)
return(list(coefficients = coefficients))
- } else if (modelName == "KMeansModel") {
- modelSize <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getKMeansModelSize", object@model)
- cluster <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getKMeansCluster", object@model, "classes")
- k <- unlist(modelSize)[1]
- size <- unlist(modelSize)[-1]
- coefficients <- t(matrix(coefficients, ncol = k))
- colnames(coefficients) <- unlist(features)
- rownames(coefficients) <- 1:k
- return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
} else {
stop(paste("Unsupported model", modelName, sep = " "))
}
@@ -213,21 +207,21 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' @examples
#' \dontrun{
#' model <- kmeans(x, centers = 2, algorithm="random")
-#'}
+#' }
setMethod("kmeans", signature(x = "DataFrame"),
function(x, centers, iter.max = 10, algorithm = c("random", "k-means||")) {
columnNames <- as.array(colnames(x))
algorithm <- match.arg(algorithm)
- model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", "fitKMeans", x@sdf,
- algorithm, iter.max, centers, columnNames)
- return(new("PipelineModel", model = model))
+ jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", x@sdf,
+ centers, iter.max, algorithm, columnNames)
+ return(new("KMeansModel", jobj = jobj))
})
-#' Get fitted result from a model
+#' Get fitted result from a k-means model
#'
-#' Get fitted result from a model, similarly to R's fitted().
+#' Get fitted result from a k-means model, similarly to R's fitted().
#'
-#' @param object A fitted MLlib model
+#' @param object A fitted k-means model
#' @return DataFrame containing fitted values
#' @rdname fitted
#' @export
@@ -237,19 +231,58 @@ setMethod("kmeans", signature(x = "DataFrame"),
#' fitted.model <- fitted(model)
#' showDF(fitted.model)
#'}
-setMethod("fitted", signature(object = "PipelineModel"),
+setMethod("fitted", signature(object = "KMeansModel"),
function(object, method = c("centers", "classes"), ...) {
- modelName <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getModelName", object@model)
+ method <- match.arg(method)
+ return(dataFrame(callJMethod(object@jobj, "fitted", method)))
+ })
- if (modelName == "KMeansModel") {
- method <- match.arg(method)
- fittedResult <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers",
- "getKMeansCluster", object@model, method)
- return(dataFrame(fittedResult))
- } else {
- stop(paste("Unsupported model", modelName, sep = " "))
- }
+#' Get the summary of a k-means model
+#'
+#' Returns the summary of a k-means model produced by kmeans(),
+#' similarly to R's summary().
+#'
+#' @param object a fitted k-means model
+#' @return the model's coefficients, size and cluster
+#' @rdname summary
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' summary(model)
+#' }
+setMethod("summary", signature(object = "KMeansModel"),
+ function(object, ...) {
+ jobj <- object@jobj
+ features <- callJMethod(jobj, "features")
+ coefficients <- callJMethod(jobj, "coefficients")
+ cluster <- callJMethod(jobj, "cluster")
+ k <- callJMethod(jobj, "k")
+ size <- callJMethod(jobj, "size")
+ coefficients <- t(matrix(coefficients, ncol = k))
+ colnames(coefficients) <- unlist(features)
+ rownames(coefficients) <- 1:k
+ return(list(coefficients = coefficients, size = size, cluster = dataFrame(cluster)))
+ })
+
+#' Make predictions from a k-means model
+#'
+#' Make predictions from a model produced by kmeans().
+#'
+#' @param object A fitted k-means model
+#' @param newData DataFrame for testing
+#' @return DataFrame containing predicted labels in a column named "prediction"
+#' @rdname predict
+#' @export
+#' @examples
+#' \dontrun{
+#' model <- kmeans(trainingData, 2)
+#' predicted <- predict(model, testData)
+#' showDF(predicted)
+#' }
+setMethod("predict", signature(object = "KMeansModel"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
})
#' Fit a Bernoulli naive Bayes model