diff options
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r-- | python/pyspark/mllib/clustering.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index 8cf20e591a..30862918c3 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -28,12 +28,12 @@ class KMeansModel(object): """A clustering model derived from the k-means method. >>> data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0]).reshape(4,2) - >>> clusters = KMeans.train(sc, sc.parallelize(data), 2, maxIterations=10, runs=30, initialization_mode="random") + >>> clusters = KMeans.train(sc.parallelize(data), 2, maxIterations=10, runs=30, initializationMode="random") >>> clusters.predict(array([0.0, 0.0])) == clusters.predict(array([1.0, 1.0])) True >>> clusters.predict(array([8.0, 9.0])) == clusters.predict(array([9.0, 8.0])) True - >>> clusters = KMeans.train(sc, sc.parallelize(data), 2) + >>> clusters = KMeans.train(sc.parallelize(data), 2) """ def __init__(self, centers_): self.centers = centers_ @@ -52,12 +52,13 @@ class KMeansModel(object): class KMeans(object): @classmethod - def train(cls, sc, data, k, maxIterations=100, runs=1, - initialization_mode="k-means||"): + def train(cls, data, k, maxIterations=100, runs=1, + initializationMode="k-means||"): """Train a k-means clustering model.""" + sc = data.context dataBytes = _get_unmangled_double_vector_rdd(data) ans = sc._jvm.PythonMLLibAPI().trainKMeansModel(dataBytes._jrdd, - k, maxIterations, runs, initialization_mode) + k, maxIterations, runs, initializationMode) if len(ans) != 1: raise RuntimeError("JVM call result had unexpected length") elif type(ans[0]) != bytearray: |