aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-03-20 14:44:21 -0400
committerXiangrui Meng <meng@databricks.com>2015-03-20 14:44:21 -0400
commit48866f789712b0cdbaf76054d1014c6df032fff1 (patch)
tree23daedb637cf736716ccf60515dffbab755a04ec /python
parenta74564591f1c824f9eed516ae79e079b355fd32b (diff)
downloadspark-48866f789712b0cdbaf76054d1014c6df032fff1.tar.gz
spark-48866f789712b0cdbaf76054d1014c6df032fff1.tar.bz2
spark-48866f789712b0cdbaf76054d1014c6df032fff1.zip
[SPARK-6095] [MLLIB] Support model save/load in Python's linear models
For Python's linear models, weights and intercept are stored in Python. This PR implements Python's linear models sava/load functions which do the same thing as scala. It can also make model import/export cross languages. Author: Yanbo Liang <ybliang8@gmail.com> Closes #5016 from yanboliang/spark-6095 and squashes the following commits: d9bb824 [Yanbo Liang] fix python style b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/classification.py58
-rw-r--r--python/pyspark/mllib/regression.py84
-rw-r--r--python/pyspark/mllib/util.py6
3 files changed, 145 insertions, 3 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index e476517370..b66159c5bf 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -21,7 +21,7 @@ import numpy
from numpy import array
from pyspark import RDD
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
@@ -99,6 +99,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
1
>>> lrm.predict(SparseVector(2, {0: 1.0}))
0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LogisticRegressionModel.load(sc, path)
+ >>> sameModel.predict(array([0.0, 1.0]))
+ 1
+ >>> sameModel.predict(SparseVector(2, {0: 1.0}))
+ 0
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def __init__(self, weights, intercept):
super(LogisticRegressionModel, self).__init__(weights, intercept)
@@ -124,6 +136,22 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
else:
return 1 if prob > self._threshold else 0
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ threshold = java_model.getThreshold().get()
+ model = LogisticRegressionModel(weights, intercept)
+ model.setThreshold(threshold)
+ return model
+
class LogisticRegressionWithSGD(object):
@@ -243,6 +271,18 @@ class SVMModel(LinearBinaryClassificationModel):
1
>>> svm.predict(SparseVector(2, {0: -1.0}))
0
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> svm.save(sc, path)
+ >>> sameModel = SVMModel.load(sc, path)
+ >>> sameModel.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> sameModel.predict(SparseVector(2, {0: -1.0}))
+ 0
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def __init__(self, weights, intercept):
super(SVMModel, self).__init__(weights, intercept)
@@ -263,6 +303,22 @@ class SVMModel(LinearBinaryClassificationModel):
else:
return 1 if margin > self._threshold else 0
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ threshold = java_model.getThreshold().get()
+ model = SVMModel(weights, intercept)
+ model.setThreshold(threshold)
+ return model
+
class SVMWithSGD(object):
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 0c21ad5787..015a786011 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,8 +18,9 @@
import numpy as np
from numpy import array
-from pyspark.mllib.common import callMLlibFunc, inherit_doc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.util import Saveable, Loader
__all__ = ['LabeledPoint', 'LinearModel',
'LinearRegressionModel', 'LinearRegressionWithSGD',
@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LinearRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LinearRegressionModel(weights, intercept)
+ return model
# train_func should take two parameters, namely data and initial_weights, and
@@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LassoModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LassoModel(weights, intercept)
+ return model
class LassoWithSGD(object):
@@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = RidgeRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = RidgeRegressionModel(weights, intercept)
+ return model
class RidgeRegressionWithSGD(object):
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index e877c720ac..c5c3468eb9 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -20,7 +20,6 @@ import warnings
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc
from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector
-from pyspark.mllib.regression import LabeledPoint
class MLUtils(object):
@@ -50,6 +49,7 @@ class MLUtils(object):
@staticmethod
def _convert_labeled_point_to_libsvm(p):
"""Converts a LabeledPoint to a string in LIBSVM format."""
+ from pyspark.mllib.regression import LabeledPoint
assert isinstance(p, LabeledPoint)
items = [str(p.label)]
v = _convert_to_vector(p.features)
@@ -92,6 +92,7 @@ class MLUtils(object):
>>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0")
>>> tempFile.flush()
@@ -110,6 +111,7 @@ class MLUtils(object):
>>> print examples[2]
(-1.0,(6,[1,3,5],[4.0,5.0,6.0]))
"""
+ from pyspark.mllib.regression import LabeledPoint
if multiclass is not None:
warnings.warn("deprecated", DeprecationWarning)
@@ -130,6 +132,7 @@ class MLUtils(object):
>>> from tempfile import NamedTemporaryFile
>>> from fileinput import input
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> from glob import glob
>>> from pyspark.mllib.util import MLUtils
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \
@@ -156,6 +159,7 @@ class MLUtils(object):
>>> from tempfile import NamedTemporaryFile
>>> from pyspark.mllib.util import MLUtils
+ >>> from pyspark.mllib.regression import LabeledPoint
>>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \
LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))]
>>> tempFile = NamedTemporaryFile(delete=True)