about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/mllib/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r-- python/pyspark/mllib/clustering.py | 18
1 files changed, 12 insertions, 6 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 464f49aeee..abbb7cf60e 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -15,6 +15,12 @@
# limitations under the License.
#
+import sys
+import array as pyarray
+
+if sys.version > '3':
+ xrange = range
+
from numpy import array
from pyspark import RDD
@@ -55,8 +61,8 @@ class KMeansModel(Saveable, Loader):
True
>>> model.predict(sparse_data[2]) == model.predict(sparse_data[3])
True
- >>> type(model.clusterCenters)
- <type 'list'>
+ >>> isinstance(model.clusterCenters, list)
+ True
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
@@ -90,7 +96,7 @@ class KMeansModel(Saveable, Loader):
return best
def save(self, sc, path):
- java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
+ java_centers = _py2java(sc, [_convert_to_vector(c) for c in self.centers])
java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
java_model.save(sc._jsc.sc(), path)
@@ -133,7 +139,7 @@ class GaussianMixtureModel(object):
... 5.7048, 4.6567, 5.5026,
... 4.5605, 5.2043, 6.2734]).reshape(5, 3))
>>> model = GaussianMixture.train(clusterdata_2, 2, convergenceTol=0.0001,
- ... maxIterations=150, seed=10)
+ ... maxIterations=150, seed=10)
>>> labels = model.predict(clusterdata_2).collect()
>>> labels[0]==labels[1]==labels[2]
True
@@ -168,8 +174,8 @@ class GaussianMixtureModel(object):
if isinstance(x, RDD):
means, sigmas = zip(*[(g.mu, g.sigma) for g in self.gaussians])
membership_matrix = callMLlibFunc("predictSoftGMM", x.map(_convert_to_vector),
- self.weights, means, sigmas)
- return membership_matrix
+ _convert_to_vector(self.weights), means, sigmas)
+ return membership_matrix.map(lambda x: pyarray.array('d', x))
class GaussianMixture(object):