author:    Xiangrui Meng <meng@databricks.com>   2015-03-17 12:14:40 -0700
committer: Xiangrui Meng <meng@databricks.com>   2015-03-17 12:14:40 -0700
commit:    c94d0626471e209ab7ebfc588f9a2992946b7ed5
tree:      b5025c4412aac661e915fe25f3f0a9aaaa371a50
parent:    d9f3e01688ad0a8d5fc2419a948a682ad7d957c9
[SPARK-6226][MLLIB] add save/load in PySpark's KMeansModel
Use `_py2java` and `_java2py` to convert the Python model to/from the Java model. (cc: yinxusen)
Author: Xiangrui Meng <meng@databricks.com>
Closes #5049 from mengxr/SPARK-6226-mengxr and squashes the following commits:
570ba81 [Xiangrui Meng] fix python style
b10b911 [Xiangrui Meng] add save/load in PySpark's KMeansModel
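For a sense of the resulting API, here is a minimal sketch of the round trip this patch enables. It mirrors the doctest added in `clustering.py`: the training data is made up for illustration, and a live SparkContext bound to `sc` (as in the `pyspark` shell) is assumed.

```python
import shutil
import tempfile

from pyspark.mllib.clustering import KMeans, KMeansModel

# Train a small model (illustrative data; assumes a running SparkContext `sc`).
data = sc.parallelize([[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0]])
model = KMeans.train(data, 2)

# save() hands the centers to the Java KMeansModel, which writes metadata
# and data under `path`; load() rebuilds the Python wrapper from them.
path = tempfile.mkdtemp()
model.save(sc, path)
sameModel = KMeansModel.load(sc, path)

assert sameModel.predict([0.0, 0.0]) == model.predict([0.0, 0.0])
shutil.rmtree(path)  # clean up the saved model directory
```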
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala |  5
-rw-r--r--  python/pyspark/mllib/clustering.py                                       | 28
-rw-r--r--  python/pyspark/mllib/common.py                                           |  4

3 files changed, 32 insertions, 5 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 707da537d2..e4e411a3c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.mllib.clustering
 
+import scala.collection.JavaConverters._
+
 import org.json4s._
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
@@ -34,6 +36,9 @@ import org.apache.spark.sql.Row
  */
 class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
 
+  /** A Java-friendly constructor that takes an Iterable of Vectors. */
+  def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
+
   /** Total number of clusters. */
   def k: Int = clusterCenters.length
 
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 949db5705a..464f49aeee 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -19,14 +19,16 @@
 from numpy import array
 
 from pyspark import RDD
 from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc
-from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 from pyspark.mllib.stat.distribution import MultivariateGaussian
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
 
 __all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
 
 
-class KMeansModel(object):
+@inherit_doc
+class KMeansModel(Saveable, Loader):
 
     """A clustering model derived from the k-means method.
@@ -55,6 +57,16 @@
     True
     >>> type(model.clusterCenters)
     <type 'list'>
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = KMeansModel.load(sc, path)
+    >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
+    True
+    >>> try:
+    ...     os.removedirs(path)
+    ... except OSError:
+    ...     pass
     """
 
     def __init__(self, centers):
@@ -77,6 +89,16 @@
                 best_distance = distance
         return best
 
+    def save(self, sc, path):
+        java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path)
+        return KMeansModel(_java2py(sc, java_model.clusterCenters()))
+
 
 class KMeans(object):
 
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 621591c26b..a539d2f284 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -70,8 +70,8 @@ def _py2java(sc, obj):
         obj = _to_java_object_rdd(obj)
     elif isinstance(obj, SparkContext):
         obj = obj._jsc
-    elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
-        obj = ListConverter().convert(obj, sc._gateway._gateway_client)
+    elif isinstance(obj, list):
+        obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
     elif isinstance(obj, JavaObject):
         pass
     elif isinstance(obj, (int, long, float, bool, basestring)):
```
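The `common.py` hunk is what makes the round trip work end to end: `_py2java` used to hand a list to Py4J's `ListConverter` without converting its elements, so a list of Python vectors never arrived on the JVM as MLlib `Vector`s. It now runs every element through `_py2java` first, and the resulting `java.util.ArrayList` satisfies the new `java.lang.Iterable[Vector]` constructor. The sketch below traces the conversion path that `save`/`load` rely on; it is illustration only, since `_py2java` and `_java2py` are private helpers, and it again assumes a live `sc`.

```python
from pyspark.mllib.common import _java2py, _py2java
from pyspark.mllib.linalg import _convert_to_vector

centers = [_convert_to_vector(c) for c in ([0.0, 0.0], [1.0, 1.0])]

# Each list element is now converted to a JVM MLlib Vector before Py4J wraps
# the list, so java_centers is a java.util.ArrayList of Vector, i.e. an
# Iterable[Vector] that the new Java-friendly constructor accepts.
java_centers = _py2java(sc, centers)
java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)

# load() goes the other way: clusterCenters() returns Java objects that
# _java2py converts back into a Python list of vectors.
centers_again = _java2py(sc, java_model.clusterCenters())
```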