aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-02-11 15:53:45 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-11 15:53:45 -0800
commit2426eb3e167fece19831070594247e9481dbbe2a (patch)
tree088a6165c0c5822d6876714bc7ba121310b794c1 /python/pyspark/ml/tests.py
parentc8f667d7c1a0b02685e17b6f498879b05ced9b9d (diff)
downloadspark-2426eb3e167fece19831070594247e9481dbbe2a.tar.gz
spark-2426eb3e167fece19831070594247e9481dbbe2a.tar.bz2
spark-2426eb3e167fece19831070594247e9481dbbe2a.zip
[MINOR][ML][PYSPARK] Cleanup test cases of clustering.py
Test cases should be removed from annotation of ```setXXX``` function, otherwise it will be parts of [Python API docs](https://spark.apache.org/docs/latest/api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans.setInitMode). cc mengxr jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #10975 from yanboliang/clustering-cleanup.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 54806ee336..e93a4e157b 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -39,6 +39,7 @@ import tempfile
from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
@@ -243,6 +244,14 @@ class ParamTests(PySparkTestCase):
"maxIter: max number of iterations (>= 0). (default: 10, current: 100)",
"seed: random seed. (default: 41, current: 43)"]))
+ def test_kmeans_param(self):
+ algo = KMeans()
+ self.assertEqual(algo.getInitMode(), "k-means||")
+ algo.setK(10)
+ self.assertEqual(algo.getK(), 10)
+ algo.setInitSteps(10)
+ self.assertEqual(algo.getInitSteps(), 10)
+
def test_hasseed(self):
noSeedSpecd = TestParams()
withSeedSpecd = TestParams(seed=42)