about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- python/pyspark/mllib/tests.py 11
1 files changed, 11 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 3436a28b29..32ed48e103 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -419,6 +419,17 @@ class ListTests(MLlibTestCase):
as NumPy arrays.
"""
+ def test_bisecting_kmeans(self):  # a single point and a one-element RDD of it must agree
+ from pyspark.mllib.clustering import BisectingKMeans
+ data = array([0.0, 0.0, 1.0, 1.0, 9.0, 8.0, 8.0, 9.0]).reshape(4, 2)  # four 2-D points
+ bskm = BisectingKMeans()
+ model = bskm.train(self.sc.parallelize(data, 2), k=4)  # use the test case's own context, as below
+ p = array([0.0, 0.0])
+ rdd_p = self.sc.parallelize([p])
+ self.assertEqual(model.predict(p), model.predict(rdd_p).first())
+ self.assertEqual(model.computeCost(p), model.computeCost(rdd_p))
+ self.assertEqual(model.k, len(model.clusterCenters))
+
def test_kmeans(self):
from pyspark.mllib.clustering import KMeans
data = [