diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-03-20 14:44:21 -0400 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-20 14:44:21 -0400 |
commit | 48866f789712b0cdbaf76054d1014c6df032fff1 (patch) | |
tree | 23daedb637cf736716ccf60515dffbab755a04ec /python/pyspark/mllib/regression.py | |
parent | a74564591f1c824f9eed516ae79e079b355fd32b (diff) | |
download | spark-48866f789712b0cdbaf76054d1014c6df032fff1.tar.gz spark-48866f789712b0cdbaf76054d1014c6df032fff1.tar.bz2 spark-48866f789712b0cdbaf76054d1014c6df032fff1.zip |
[SPARK-6095] [MLLIB] Support model save/load in Python's linear models
For Python's linear models, weights and intercept are stored in Python.
This PR implements Python's linear models sava/load functions which do the same thing as scala.
It can also make model import/export cross languages.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #5016 from yanboliang/spark-6095 and squashes the following commits:
d9bb824 [Yanbo Liang] fix python style
b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r-- | python/pyspark/mllib/regression.py | 84 |
1 files changed, 83 insertions, 1 deletions
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 0c21ad5787..015a786011 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,8 +18,9 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc, inherit_doc +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.util import Saveable, Loader __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'LinearRegressionWithSGD', @@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LinearRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LinearRegressionModel(weights, intercept) + return model # train_func should take two parameters, namely data and initial_weights, and @@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LassoModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LassoModel(weights, intercept) + return model class LassoWithSGD(object): @@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = RidgeRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = RidgeRegressionModel(weights, intercept) + return model class RidgeRegressionWithSGD(object): |