aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r--python/pyspark/mllib/clustering.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index fe4c4cc509..e2492eef5b 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -16,7 +16,7 @@
#
from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['KMeansModel', 'KMeans']
@@ -80,10 +80,8 @@ class KMeans(object):
@classmethod
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
"""Train a k-means clustering model."""
- # cache serialized data to avoid objects over head in JVM
- jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True)
- model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs,
- initializationMode)
+ model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
+ runs, initializationMode)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])