aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2017-01-12 22:27:57 -0800
committerYanbo Liang <ybliang8@gmail.com>2017-01-12 22:27:57 -0800
commit7f24a0b6c32c56a38cf879d953bbd523922ab9c9 (patch)
treed60ea1d9a8fcf309fb5c938452ac7018fbc5dd38
parent3356b8b6a9184fcab8d0fe993f3545c3beaa4d99 (diff)
downloadspark-7f24a0b6c32c56a38cf879d953bbd523922ab9c9.tar.gz
spark-7f24a0b6c32c56a38cf879d953bbd523922ab9c9.tar.bz2
spark-7f24a0b6c32c56a38cf879d953bbd523922ab9c9.zip
[SPARK-19142][SPARKR] spark.kmeans should take seed, initSteps, and tol as parameters
## What changes were proposed in this pull request? spark.kmeans doesn't have interface to set initSteps, seed and tol. As Spark Kmeans algorithm doesn't take the same set of parameters as R kmeans, we should maintain a different interface in spark.kmeans. Add missing parameters and corresponding document. Modified existing unit tests to take additional parameters. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #16523 from wangmiao1981/kmeans.
-rw-r--r--R/pkg/R/mllib_clustering.R13
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_clustering.R20
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala9
3 files changed, 39 insertions, 3 deletions
diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R
index c443588387..ca5182d527 100644
--- a/R/pkg/R/mllib_clustering.R
+++ b/R/pkg/R/mllib_clustering.R
@@ -175,6 +175,10 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
#' @param k number of centers.
#' @param maxIter maximum iteration number.
#' @param initMode the initialization algorithm choosen to fit the model.
+#' @param seed the random seed for cluster initialization.
+#' @param initSteps the number of steps for the k-means|| initialization mode.
+#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0.
+#' @param tol convergence tolerance of iterations.
#' @param ... additional argument(s) passed to the method.
#' @return \code{spark.kmeans} returns a fitted k-means model.
#' @rdname spark.kmeans
@@ -204,11 +208,16 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact
#' @note spark.kmeans since 2.0.0
#' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
- function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) {
+ function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random"),
+ seed = NULL, initSteps = 2, tol = 1E-4) {
formula <- paste(deparse(formula), collapse = "")
initMode <- match.arg(initMode)
+ if (!is.null(seed)) {
+ seed <- as.character(as.integer(seed))
+ }
jobj <- callJStatic("org.apache.spark.ml.r.KMeansWrapper", "fit", data@sdf, formula,
- as.integer(k), as.integer(maxIter), initMode)
+ as.integer(k), as.integer(maxIter), initMode, seed,
+ as.integer(initSteps), as.numeric(tol))
new("KMeansModel", jobj = jobj)
})
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index 1980fffd80..f013991002 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -132,6 +132,26 @@ test_that("spark.kmeans", {
expect_true(summary2$is.loaded)
unlink(modelPath)
+
+ # Test Kmeans on dataset that is sensitive to seed value
+ col1 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+ col2 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+ col3 <- c(1, 2, 3, 4, 0, 1, 2, 3, 4, 0)
+ cols <- as.data.frame(cbind(col1, col2, col3))
+ df <- createDataFrame(cols)
+
+ model1 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
+ initMode = "random", seed = 1, tol = 1E-5)
+ model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
+ initMode = "random", seed = 22222, tol = 1E-5)
+
+ fitted.model1 <- fitted(model1)
+ fitted.model2 <- fitted(model2)
+ # The predicted clusters are different
+ expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction),
+ c(0, 1, 2, 3))
+ expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction),
+ c(0, 1, 2))
})
test_that("spark.lda with libsvm", {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ea9458525a..a1fefd31c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -68,7 +68,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
formula: String,
k: Int,
maxIter: Int,
- initMode: String): KMeansWrapper = {
+ initMode: String,
+ seed: String,
+ initSteps: Int,
+ tol: Double): KMeansWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
@@ -87,6 +90,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
.setMaxIter(maxIter)
.setInitMode(initMode)
.setFeaturesCol(rFormula.getFeaturesCol)
+ .setInitSteps(initSteps)
+ .setTol(tol)
+
+ if (seed != null && seed.length > 0) kMeans.setSeed(seed.toInt)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, kMeans))