diff options
author | Xusen Yin <yinxusen@gmail.com> | 2015-03-20 14:53:59 -0400 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-20 14:53:59 -0400 |
commit | 25636d9867c6bc901463b6b227cb444d701cfdd1 (patch) | |
tree | 0decdcff3c8d20c399d792cfde375f9c737acd8d /python | |
parent | 5e6ad24ff645a9b0f63d9c0f17193550963aa0a7 (diff) | |
download | spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.gz spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.bz2 spark-25636d9867c6bc901463b6b227cb444d701cfdd1.zip |
[SPARK-6096][MLlib] Add Naive Bayes load/save methods in Python
See [SPARK-6096](https://issues.apache.org/jira/browse/SPARK-6096).
Author: Xusen Yin <yinxusen@gmail.com>
Closes #5090 from yinxusen/SPARK-6096 and squashes the following commits:
bd0fea5 [Xusen Yin] fix style problem, etc.
3fd41f2 [Xusen Yin] use hanging indent in Python style
e83803d [Xusen Yin] fix Python style
d6dbde5 [Xusen Yin] fix python call java error
a054bb3 [Xusen Yin] add save load for NaiveBayes python
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/classification.py | 31 |
1 files changed, 30 insertions, 1 deletions
```diff
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index b66159c5bf..6766f3ebb8 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -24,6 +24,7 @@ from pyspark import RDD
 from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
 from pyspark.mllib.linalg import SparseVector, _convert_to_vector
 from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
 
 
 __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
@@ -359,7 +360,8 @@ class SVMWithSGD(object):
         return _regression_train_wrapper(train, SVMModel, data, initialWeights)
 
 
-class NaiveBayesModel(object):
+@inherit_doc
+class NaiveBayesModel(Saveable, Loader):
     """
     Model for Naive Bayes classifiers.
 
@@ -390,6 +392,16 @@ class NaiveBayesModel(object):
     0.0
     >>> model.predict(SparseVector(2, {0: 1.0}))
     1.0
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> model.save(sc, path)
+    >>> sameModel = NaiveBayesModel.load(sc, path)
+    >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
+    True
+    >>> try:
+    ...     os.removedirs(path)
+    ... except OSError:
+    ...     pass
     """
 
     def __init__(self, labels, pi, theta):
@@ -404,6 +416,23 @@ class NaiveBayesModel(object):
         x = _convert_to_vector(x)
         return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
 
+    def save(self, sc, path):
+        java_labels = _py2java(sc, self.labels.tolist())
+        java_pi = _py2java(sc, self.pi.tolist())
+        java_theta = _py2java(sc, self.theta.tolist())
+        java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel(
+            java_labels, java_pi, java_theta)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
+            sc._jsc.sc(), path)
+        py_labels = _java2py(sc, java_model.labels())
+        py_pi = _java2py(sc, java_model.pi())
+        py_theta = _java2py(sc, java_model.theta())
+        return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
+
 
 class NaiveBayes(object):
 
```