aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/classification.py31
1 files changed, 30 insertions, 1 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index b66159c5bf..6766f3ebb8 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -24,6 +24,7 @@ from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
+from pyspark.mllib.util import Saveable, Loader, inherit_doc
__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
@@ -359,7 +360,8 @@ class SVMWithSGD(object):
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
-class NaiveBayesModel(object):
+@inherit_doc
+class NaiveBayesModel(Saveable, Loader):
"""
Model for Naive Bayes classifiers.
@@ -390,6 +392,16 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(SparseVector(2, {0: 1.0}))
1.0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = NaiveBayesModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {0: 1.0})) == model.predict(SparseVector(2, {0: 1.0}))
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self, labels, pi, theta):
@@ -404,6 +416,23 @@ class NaiveBayesModel(object):
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
+ def save(self, sc, path):
+ java_labels = _py2java(sc, self.labels.tolist())
+ java_pi = _py2java(sc, self.pi.tolist())
+ java_theta = _py2java(sc, self.theta.tolist())
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel(
+ java_labels, java_pi, java_theta)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.NaiveBayesModel.load(
+ sc._jsc.sc(), path)
+ py_labels = _java2py(sc, java_model.labels())
+ py_pi = _java2py(sc, java_model.pi())
+ py_theta = _java2py(sc, java_model.theta())
+ return NaiveBayesModel(py_labels, py_pi, numpy.array(py_theta))
+
class NaiveBayes(object):