aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/ml/clustering.py15
-rw-r--r--python/pyspark/ml/tests.py9
2 files changed, 9 insertions, 15 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 60d1c9aaec..12afb88563 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -113,10 +113,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
def setK(self, value):
"""
Sets the value of :py:attr:`k`.
-
- >>> algo = KMeans().setK(10)
- >>> algo.getK()
- 10
"""
self._paramMap[self.k] = value
return self
@@ -132,13 +128,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
def setInitMode(self, value):
"""
Sets the value of :py:attr:`initMode`.
-
- >>> algo = KMeans()
- >>> algo.getInitMode()
- 'k-means||'
- >>> algo = algo.setInitMode("random")
- >>> algo.getInitMode()
- 'random'
"""
self._paramMap[self.initMode] = value
return self
@@ -154,10 +143,6 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
def setInitSteps(self, value):
"""
Sets the value of :py:attr:`initSteps`.
-
- >>> algo = KMeans().setInitSteps(10)
- >>> algo.getInitSteps()
- 10
"""
self._paramMap[self.initSteps] = value
return self
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 54806ee336..e93a4e157b 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -39,6 +39,7 @@ import tempfile
from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
@@ -243,6 +244,14 @@ class ParamTests(PySparkTestCase):
"maxIter: max number of iterations (>= 0). (default: 10, current: 100)",
"seed: random seed. (default: 41, current: 43)"]))
+ def test_kmeans_param(self):
+ algo = KMeans()
+ self.assertEqual(algo.getInitMode(), "k-means||")
+ algo.setK(10)
+ self.assertEqual(algo.getK(), 10)
+ algo.setInitSteps(10)
+ self.assertEqual(algo.getInitSteps(), 10)
+
def test_hasseed(self):
noSeedSpecd = TestParams()
withSeedSpecd = TestParams(seed=42)