author     Yanbo Liang <ybliang8@gmail.com>       2015-03-20 14:44:21 -0400
committer  Xiangrui Meng <meng@databricks.com>    2015-03-20 14:44:21 -0400
commit     48866f789712b0cdbaf76054d1014c6df032fff1 (patch)
tree       23daedb637cf736716ccf60515dffbab755a04ec /python/pyspark/mllib/regression.py
parent     a74564591f1c824f9eed516ae79e079b355fd32b (diff)
[SPARK-6095] [MLLIB] Support model save/load in Python's linear models
For Python's linear models, weights and intercept are stored in Python. This PR implements save/load functions for Python's linear models that do the same thing as the Scala implementation, which also makes model import/export work across languages.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #5016 from yanboliang/spark-6095 and squashes the following commits:

d9bb824 [Yanbo Liang] fix python style
b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation
Diffstat (limited to 'python/pyspark/mllib/regression.py')
-rw-r--r--  python/pyspark/mllib/regression.py | 84
1 file changed, 83 insertions(+), 1 deletion(-)
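
Before the diff itself, here is a minimal usage sketch of the API this commit adds, mirroring the doctests in the diff below. The SparkContext setup, app name, and temporary path are illustrative only and are not part of the commit.

```python
import shutil
import tempfile
import numpy as np

from pyspark import SparkContext
from pyspark.mllib.regression import (LabeledPoint, LinearRegressionModel,
                                      LinearRegressionWithSGD)

sc = SparkContext(appName="linear-model-save-load")  # illustrative app name

# Train a small linear regression model, as in the doctests below.
data = [LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]),
        LabeledPoint(3.0, [2.0]), LabeledPoint(2.0, [3.0])]
lrm = LinearRegressionWithSGD.train(sc.parallelize(data),
                                    initialWeights=np.array([1.0]))

# Save the model to a temporary directory, then load it back with the new API.
path = tempfile.mkdtemp()
lrm.save(sc, path)
sameModel = LinearRegressionModel.load(sc, path)
print(sameModel.predict(np.array([1.0])))

shutil.rmtree(path, ignore_errors=True)
sc.stop()
```
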
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 0c21ad5787..015a786011 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,8 +18,9 @@
import numpy as np
from numpy import array
-from pyspark.mllib.common import callMLlibFunc, inherit_doc
+from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.util import Saveable, Loader
__all__ = ['LabeledPoint', 'LinearModel',
'LinearRegressionModel', 'LinearRegressionWithSGD',
@@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LinearRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LinearRegressionModel(weights, intercept)
+ return model
# train_func should take two parameters, namely data and initial_weights, and
@@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = LassoModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = LassoModel(weights, intercept)
+ return model
class LassoWithSGD(object):
@@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase):
True
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lrm.save(sc, path)
+ >>> sameModel = RidgeRegressionModel.load(sc, path)
+ >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5
+ True
+ >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5
+ True
+ >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
+ True
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
>>> data = [
... LabeledPoint(0.0, SparseVector(1, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(1, {0: 1.0})),
@@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase):
>>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5
True
"""
+ def save(self, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel(
+ _py2java(sc, self._coeff), self.intercept)
+ java_model.save(sc._jsc.sc(), path)
+
+ @classmethod
+ def load(cls, sc, path):
+ java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load(
+ sc._jsc.sc(), path)
+ weights = _java2py(sc, java_model.weights())
+ intercept = java_model.intercept()
+ model = RidgeRegressionModel(weights, intercept)
+ return model
class RidgeRegressionWithSGD(object):
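
The three save/load pairs added above follow one pattern: convert the Python weights to a JVM vector with _py2java, construct (or load) the corresponding Scala model through Py4J, delegate the actual I/O to Scala, and convert the result back with _java2py. The helpers below are a hypothetical sketch of that shared shape; they are not part of this commit, and the names _save_java_linear_model / _load_java_linear_model are invented for illustration.

```python
# Hypothetical helpers (not in this commit) illustrating the shared
# save/load pattern used by LinearRegressionModel, LassoModel, and
# RidgeRegressionModel above.
from pyspark.mllib.common import _py2java, _java2py


def _save_java_linear_model(sc, model, java_class_name, path):
    """Build the matching JVM model via Py4J and delegate save() to Scala."""
    java_class = getattr(sc._jvm.org.apache.spark.mllib.regression, java_class_name)
    java_model = java_class(_py2java(sc, model.weights), model.intercept)
    java_model.save(sc._jsc.sc(), path)


def _load_java_linear_model(sc, py_class, java_class_name, path):
    """Load the JVM model, then rebuild the Python wrapper from its fields."""
    java_class = getattr(sc._jvm.org.apache.spark.mllib.regression, java_class_name)
    java_model = java_class.load(sc._jsc.sc(), path)
    weights = _java2py(sc, java_model.weights())
    return py_class(weights, java_model.intercept())


# For example, the LassoModel methods above are equivalent to:
#   _save_java_linear_model(sc, lasso_model, "LassoModel", path)
#   _load_java_linear_model(sc, LassoModel, "LassoModel", path)
```
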