aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r--python/pyspark/mllib/clustering.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 8cf20e591a..30862918c3 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -28,12 +28,12 @@ class KMeansModel(object):
"""A clustering model derived from the k-means method.
>>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2)
- >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random")
+ >>> clusters = KMeans.train(sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random")
>>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0]))
True
>>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0]))
True
- >>> clusters = KMeans.train(sc, sc.parallelize(data), 2)
+ >>> clusters = KMeans.train(sc.parallelize(data), 2)
"""
def __init__(self, centers_):
self.centers = centers_
@@ -52,12 +52,13 @@ class KMeansModel(object):
class KMeans(object):
@classmethod
- def train(cls, sc, data, k, maxIterations=100, runs=1,
- initialization_mode="k-means||"):
+ def train(cls, data, k, maxIterations=100, runs=1,
+ initializationMode="k-means||"):
"""Train a k-means clustering model."""
+ sc = data.context
dataBytes = _get_unmangled_double_vector_rdd(data)
ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd,
- k, maxIterations, runs, initialization_mode)
+ k, maxIterations, runs, initializationMode)
if len(ans) != 1:
raise RuntimeError("JVM call result had unexpected length")
elif type(ans[0]) != bytearray: