aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/clustering.py4
-rw-r--r--python/pyspark/mllib/tests.py17
2 files changed, 18 insertions, 3 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e2492eef5b..6b713aa393 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ class KMeansModel(object):
class KMeans(object):
@classmethod
- def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
"""Train a k-means clustering model."""
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
- runs, initializationMode)
+ runs, initializationMode, seed)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 140c22b5fd..f48e3d6dac 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -140,7 +140,7 @@ class ListTests(PySparkTestCase):
as NumPy arrays.
"""
- def test_clustering(self):
+ def test_kmeans(self):
from pyspark.mllib.clustering import KMeans
data = [
[0, 1.1],
@@ -152,6 +152,21 @@ class ListTests(PySparkTestCase):
self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1]))
self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3]))
+ def test_kmeans_deterministic(self):
+ from pyspark.mllib.clustering import KMeans
+ X = range(0, 100, 10)
+ Y = range(0, 100, 10)
+ data = [[x, y] for x, y in zip(X, Y)]
+ clusters1 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ clusters2 = KMeans.train(self.sc.parallelize(data),
+ 3, initializationMode="k-means||", seed=42)
+ centers1 = clusters1.centers
+ centers2 = clusters2.centers
+ for c1, c2 in zip(centers1, centers2):
+ # TODO: Allow small numeric difference.
+ self.assertTrue(array_equal(c1, c2))
+
def test_classification(self):
from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes
from pyspark.mllib.tree import DecisionTree