aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author: Xusen Yin <yinxusen@gmail.com>, 2015-03-20 14:53:59 -0400
committer: Xiangrui Meng <meng@databricks.com>, 2015-03-20 14:53:59 -0400
commit: 25636d9867c6bc901463b6b227cb444d701cfdd1 (patch)
tree: 0decdcff3c8d20c399d792cfde375f9c737acd8d /python
parent: 5e6ad24ff645a9b0f63d9c0f17193550963aa0a7 (diff)
download: spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.gz
spark-25636d9867c6bc901463b6b227cb444d701cfdd1.tar.bz2
spark-25636d9867c6bc901463b6b227cb444d701cfdd1.zip
[Spark 6096][MLlib] Add Naive Bayes load save methods in Python
See [SPARK-6096](https://issues.apache.org/jira/browse/SPARK-6096).

Author: Xusen Yin <yinxusen@gmail.com>

Closes #5090 from yinxusen/SPARK-6096 and squashes the following commits:

bd0fea5 [Xusen Yin] fix style problem, etc.
3fd41f2 [Xusen Yin] use hanging indent in Python style
e83803d [Xusen Yin] fix Python style
d6dbde5 [Xusen Yin] fix python call java error
a054bb3 [Xusen Yin] add save load for NaiveBayes python
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/classification.py 31
1 files changed, 30 insertions, 1 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index b66159c5bf..6766f3ebb8 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -24,6 +24,7 @@ from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
@@ -359,7 +360,8 @@ class SVMWithSGD(object):
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
-class NaiveBayesModel(object):
+@inherit_doc
+class NaiveBayesModel(Saveable, Loader):
"""
Model for Naive Bayes classifiers.
@@ -390,6 +392,16 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(SparseVector(2, {0: 1.0}))
1.0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = NaiveBayesModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, labels, pi, theta):
@@ -404,6 +416,23 @@ class NaiveBayesModel(object):
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
+ def save(self, sc, path):
+ java_labels = _py2java(sc, self.labels.tolist())
+ java_pi = _py2java(sc, self.pi.tolist())
+ java_theta = _py2java(sc, self.theta.tolist())
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel(
+ java_labels, java_pi, java_theta)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
+ sc._jsc.sc(), path)
+ py_labels = _java2py(sc, java_model.labels())
+ py_pi = _java2py(sc, java_model.pi())
+ py_theta = _java2py(sc, java_model.theta())
+ return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
+
class NaiveBayes(object):