aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/clustering.py
diff options
context:
space:
mode:
authornate.crosswhite <nate.crosswhite@stresearch.com>2015-01-21 10:32:10 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-21 10:32:10 -0800
commit7450a992b3b543a373c34fc4444a528954ac4b4a (patch)
tree1e3c63168367b3a25335f34dc0d8a58ffa39477f /python/pyspark/mllib/clustering.py
parentaa1e22b17b4ce885febe6970a2451c7d17d0acfb (diff)
downloadspark-7450a992b3b543a373c34fc4444a528954ac4b4a.tar.gz
spark-7450a992b3b543a373c34fc4444a528954ac4b4a.tar.bz2
spark-7450a992b3b543a373c34fc4444a528954ac4b4a.zip
[SPARK-4749] [mllib]: Allow initializing KMeans clusters using a seed
This implements the functionality for SPARK-4749 and provides units tests in Scala and PySpark Author: nate.crosswhite <nate.crosswhite@stresearch.com> Author: nxwhite-str <nxwhite-str@users.noreply.github.com> Author: Xiangrui Meng <meng@databricks.com> Closes #3610 from nxwhite-str/master and squashes the following commits: a2ebbd3 [nxwhite-str] Merge pull request #1 from mengxr/SPARK-4749-kmeans-seed 7668124 [Xiangrui Meng] minor updates f8d5928 [nate.crosswhite] Addressing PR issues 277d367 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 9156a57 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 5d087b4 [nate.crosswhite] Adding KMeans train with seed and Scala unit test 616d111 [nate.crosswhite] Merge remote-tracking branch 'upstream/master' 35c1884 [nate.crosswhite] Add kmeans initial seed to pyspark API
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r--python/pyspark/mllib/clustering.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index e2492eef5b..6b713aa393 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -78,10 +78,10 @@ class KMeansModel(object):
class KMeans(object):
@classmethod
- def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
+ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||", seed=None):
"""Train a k-means clustering model."""
model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
- runs, initializationMode)
+ runs, initializationMode, seed)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])