diff options
author | nate.crosswhite <nate.crosswhite@stresearch.com> | 2015-01-21 10:32:10 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-01-21 10:32:10 -0800 |
commit | 7450a992b3b543a373c34fc4444a528954ac4b4a (patch) | |
tree | 1e3c63168367b3a25335f34dc0d8a58ffa39477f /python | |
parent | aa1e22b17b4ce885febe6970a2451c7d17d0acfb (diff) | |
download | spark-7450a992b3b543a373c34fc4444a528954ac4b4a.tar.gz spark-7450a992b3b543a373c34fc4444a528954ac4b4a.tar.bz2 spark-7450a992b3b543a373c34fc4444a528954ac4b4a.zip |
[SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed
This implements the functionality for SPARK-4749 and provides units tests in Scala and PySpark
Author: nate.crosswhite <nate.crosswhite@stresearch.com>
Author: nxwhite-str <nxwhite-str@users.noreply.github.com>
Author: Xiangrui Meng <meng@databricks.com>
Closes #3610 from nxwhite-str/master and squashes the following commits:
a2ebbd3 [nxwhite-str] Merge pull request #1 from mengxr/SPARK-4749-kmeans-seed
7668124 [Xiangrui Meng] minor updates
f8d5928 [nate.crosswhite] Addressing PR issues
277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test
616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master'
35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/clustering.py | 4 | ||||
-rw-r--r-- | python/pyspark/mllib/tests.py | 17 |
2 files changed, 18 insertions, 3 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index e2492eef5b..6b713aa393 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -78,10 +78,10 @@ class KMeansModel(object): class KMeans(object): @classmethod - def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"): + def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None): """Train a k-means clustering model.""" model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations, - runs, initializationMode) + runs, initializationMode, seed) centers = callJavaFunc(rdd.context, model.clusterCenters) return KMeansModel([c.toArray() for c in centers]) diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 140c22b5fd..f48e3d6dac 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -140,7 +140,7 @@ class ListTests(PySparkTestCase): as NumPy arrays. """ - def test_clustering(self): + def test_kmeans(self): from pyspark.mllib.clustering import KMeans data = [ [0, 1.1], @@ -152,6 +152,21 @@ class ListTests(PySparkTestCase): self.assertEquals(clusters.predict(data[0]), clusters.predict(data[1])) self.assertEquals(clusters.predict(data[2]), clusters.predict(data[3])) + def test_kmeans_deterministic(self): + from pyspark.mllib.clustering import KMeans + X = range(0, 100, 10) + Y = range(0, 100, 10) + data = [[x, y] for x, y in zip(X, Y)] + clusters1 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + clusters2 = KMeans.train(self.sc.parallelize(data), + 3, initializationMode="k-means||", seed=42) + centers1 = clusters1.centers + centers2 = clusters2.centers + for c1, c2 in zip(centers1, centers2): + # TODO: Allow small numeric difference. + self.assertTrue(array_equal(c1, c2)) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree |