aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--R/pkg/R/mllib_clustering.R16
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib_clustering.R15
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala2
3 files changed, 23 insertions, 10 deletions
diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R
index 3b782ce41e..8823f90775 100644
--- a/R/pkg/R/mllib_clustering.R
+++ b/R/pkg/R/mllib_clustering.R
@@ -375,10 +375,13 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"
#' @param object a fitted k-means model.
#' @return \code{summary} returns summary information of the fitted model, which is a list.
-#' The list includes the model's \code{k} (number of cluster centers),
+#' The list includes the model's \code{k} (the configured number of cluster centers),
#' \code{coefficients} (model cluster centers),
-#' \code{size} (number of data points in each cluster), and \code{cluster}
-#' (cluster centers of the transformed data).
+#' \code{size} (number of data points in each cluster), \code{cluster}
+#' (cluster centers of the transformed data), \code{is.loaded} (whether the model is loaded
+#' from a saved file), and \code{clusterSize}
+#' (the actual number of cluster centers. When using initMode = "random",
+#' \code{clusterSize} may not equal \code{k}).
#' @rdname spark.kmeans
#' @export
#' @note summary(KMeansModel) since 2.0.0
@@ -390,16 +393,17 @@ setMethod("summary", signature(object = "KMeansModel"),
coefficients <- callJMethod(jobj, "coefficients")
k <- callJMethod(jobj, "k")
size <- callJMethod(jobj, "size")
- coefficients <- t(matrix(unlist(coefficients), ncol = k))
+ clusterSize <- callJMethod(jobj, "clusterSize")
+ coefficients <- t(matrix(unlist(coefficients), ncol = clusterSize))
colnames(coefficients) <- unlist(features)
- rownames(coefficients) <- 1:k
+ rownames(coefficients) <- 1:clusterSize
cluster <- if (is.loaded) {
NULL
} else {
dataFrame(callJMethod(jobj, "cluster"))
}
list(k = k, coefficients = coefficients, size = size,
- cluster = cluster, is.loaded = is.loaded)
+ cluster = cluster, is.loaded = is.loaded, clusterSize = clusterSize)
})
# Predicted values based on a k-means model
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index 28a6eeba2c..1661e987b7 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -196,13 +196,20 @@ test_that("spark.kmeans", {
model2 <- spark.kmeans(data = df, ~ ., k = 5, maxIter = 10,
initMode = "random", seed = 22222, tol = 1E-5)
- fitted.model1 <- fitted(model1)
- fitted.model2 <- fitted(model2)
+ summary.model1 <- summary(model1)
+ summary.model2 <- summary(model2)
+ cluster1 <- summary.model1$cluster
+ cluster2 <- summary.model2$cluster
+ clusterSize1 <- summary.model1$clusterSize
+ clusterSize2 <- summary.model2$clusterSize
+
# The predicted clusters are different
- expect_equal(sort(collect(distinct(select(fitted.model1, "prediction")))$prediction),
+ expect_equal(sort(collect(distinct(select(cluster1, "prediction")))$prediction),
c(0, 1, 2, 3))
- expect_equal(sort(collect(distinct(select(fitted.model2, "prediction")))$prediction),
+ expect_equal(sort(collect(distinct(select(cluster2, "prediction")))$prediction),
c(0, 1, 2))
+ expect_equal(clusterSize1, 4)
+ expect_equal(clusterSize2, 3)
})
test_that("spark.lda with libsvm", {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index a1fefd31c0..8d596863b4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -43,6 +43,8 @@ private[r] class KMeansWrapper private (
lazy val cluster: DataFrame = kMeansModel.summary.cluster
+ lazy val clusterSize: Int = kMeansModel.clusterCenters.size
+
def fitted(method: String): DataFrame = {
if (method == "centers") {
kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)