aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala9
1 files changed, 8 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index ea9458525a..a1fefd31c0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -68,7 +68,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
formula: String,
k: Int,
maxIter: Int,
- initMode: String): KMeansWrapper = {
+ initMode: String,
+ seed: String,
+ initSteps: Int,
+ tol: Double): KMeansWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
@@ -87,6 +90,10 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
.setMaxIter(maxIter)
.setInitMode(initMode)
.setFeaturesCol(rFormula.getFeaturesCol)
+ .setInitSteps(initSteps)
+ .setTol(tol)
+
+ if (seed != null && seed.length > 0) kMeans.setSeed(seed.toInt)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, kMeans))