aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-03-17 12:14:40 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-17 12:14:40 -0700
commitc94d0626471e209ab7ebfc588f9a2992946b7ed5 (patch)
treeb5025c4412aac661e915fe25f3f0a9aaaa371a50
parentd9f3e01688ad0a8d5fc2419a948a682ad7d957c9 (diff)
downloadspark-c94d0626471e209ab7ebfc588f9a2992946b7ed5.tar.gz
spark-c94d0626471e209ab7ebfc588f9a2992946b7ed5.tar.bz2
spark-c94d0626471e209ab7ebfc588f9a2992946b7ed5.zip
[SPARK-6226][MLLIB] add save/load in PySpark's KMeansModel
Use `_py2java` and `_java2py` to convert Python model to/from Java model. yinxusen Author: Xiangrui Meng <meng@databricks.com> Closes #5049 from mengxr/SPARK-6226-mengxr and squashes the following commits: 570ba81 [Xiangrui Meng] fix python style b10b911 [Xiangrui Meng] add save/load in PySpark's KMeansModel
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala5
-rw-r--r--python/pyspark/mllib/clustering.py28
-rw-r--r--python/pyspark/mllib/common.py4
3 files changed, 32 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
index 707da537d2..e4e411a3c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.clustering
+import scala.collection.JavaConverters._
+
import org.json4s._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods._
@@ -34,6 +36,9 @@ import org.apache.spark.sql.Row
*/
class KMeansModel (val clusterCenters: Array[Vector]) extends Saveable with Serializable {
+ /** A Java-friendly constructor that takes an Iterable of Vectors. */
+ def this(centers: java.lang.Iterable[Vector]) = this(centers.asScala.toArray)
+
/** Total number of clusters. */
def k: Int = clusterCenters.length
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 949db5705a..464f49aeee 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -19,14 +19,16 @@ from numpy import array
from pyspark import RDD
from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc
-from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.stat.distribution import MultivariateGaussian
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
-class KMeansModel(object):
+@inherit_doc
+class KMeansModel(Saveable, Loader):
"""A clustering model derived from the k-means method.
@@ -55,6 +57,16 @@ class KMeansModel(object):
True
>>> type(model.clusterCenters)
<type 'list'>
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = KMeansModel.load(sc, path)
+ >>> sameModel.predict(sparse_data[0]) == model.predict(sparse_data[0])
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, centers):
@@ -77,6 +89,16 @@ class KMeansModel(object):
best_distance = distance
return best
+ def save(self, sc, path):
+ java_centers = _py2java(sc, map(_convert_to_vector, self.centers))
+ java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel(java_centers)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.clustering.KMeansModel.load(sc._jsc.sc(), path)
+ return KMeansModel(_java2py(sc, java_model.clusterCenters()))
+
class KMeans(object):
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 621591c26b..a539d2f284 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -70,8 +70,8 @@ def _py2java(sc, obj):
obj = _to_java_object_rdd(obj)
elif isinstance(obj, SparkContext):
obj = obj._jsc
- elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
- obj = ListConverter().convert(obj, sc._gateway._gateway_client)
+ elif isinstance(obj, list):
+ obj = ListConverter().convert([_py2java(sc, x) for x in obj], sc._gateway._gateway_client)
elif isinstance(obj, JavaObject):
pass
elif isinstance(obj, (int, long, float, bool, basestring)):