diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-03-20 14:44:21 -0400 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-20 14:44:21 -0400 |
commit | 48866f789712b0cdbaf76054d1014c6df032fff1 (patch) | |
tree | 23daedb637cf736716ccf60515dffbab755a04ec /python | |
parent | a74564591f1c824f9eed516ae79e079b355fd32b (diff) | |
download | spark-48866f789712b0cdbaf76054d1014c6df032fff1.tar.gz spark-48866f789712b0cdbaf76054d1014c6df032fff1.tar.bz2 spark-48866f789712b0cdbaf76054d1014c6df032fff1.zip |
[SPARK-6095] [MLLIB] Support model save/load in Python's linear models
For Python's linear models, weights and intercept are stored in Python.
This PR implements Python's linear models sava/load functions which do the same thing as scala.
It can also make model import/export cross languages.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #5016 from yanboliang/spark-6095 and squashes the following commits:
d9bb824 [Yanbo Liang] fix python style
b3813ca [Yanbo Liang] linear model save/load for Python reuse the Scala implementation
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/classification.py | 58 | ||||
-rw-r--r-- | python/pyspark/mllib/regression.py | 84 | ||||
-rw-r--r-- | python/pyspark/mllib/util.py | 6 |
3 files changed, 145 insertions, 3 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e476517370..b66159c5bf 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -21,7 +21,7 @@ import numpy from numpy import array from pyspark import RDD -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py from pyspark.mllib.linalg import SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper @@ -99,6 +99,18 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): 1 >>> lrm.predict(SparseVector(2, {0: 1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LogisticRegressionModel.load(sc, path) + >>> sameModel.predict(array([0.0, 1.0])) + 1 + >>> sameModel.predict(SparseVector(2, {0: 1.0})) + 0 + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def __init__(self, weights, intercept): super(LogisticRegressionModel, self).__init__(weights, intercept) @@ -124,6 +136,22 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): else: return 1 if prob > self._threshold else 0 + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + threshold = java_model.getThreshold().get() + model = LogisticRegressionModel(weights, intercept) + model.setThreshold(threshold) + return model + class LogisticRegressionWithSGD(object): @@ -243,6 +271,18 @@ class SVMModel(LinearBinaryClassificationModel): 1 >>> svm.predict(SparseVector(2, {0: -1.0})) 0 + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> svm.save(sc, path) + >>> sameModel = SVMModel.load(sc, path) + >>> sameModel.predict(SparseVector(2, {1: 1.0})) + 1 + >>> sameModel.predict(SparseVector(2, {0: -1.0})) + 0 + >>> try: + ... os.removedirs(path) + ... except: + ... pass """ def __init__(self, weights, intercept): super(SVMModel, self).__init__(weights, intercept) @@ -263,6 +303,22 @@ class SVMModel(LinearBinaryClassificationModel): else: return 1 if margin > self._threshold else 0 + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.classification.SVMModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + threshold = java_model.getThreshold().get() + model = SVMModel(weights, intercept) + model.setThreshold(threshold) + return model + class SVMWithSGD(object): diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 0c21ad5787..015a786011 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -18,8 +18,9 @@ import numpy as np from numpy import array -from pyspark.mllib.common import callMLlibFunc, inherit_doc +from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc from pyspark.mllib.linalg import SparseVector, _convert_to_vector +from pyspark.mllib.util import Saveable, Loader __all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'LinearRegressionWithSGD', @@ -114,6 +115,20 @@ class LinearRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LinearRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -126,6 +141,19 @@ class LinearRegressionModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LinearRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LinearRegressionModel(weights, intercept) + return model # train_func should take two parameters, namely data and initial_weights, and @@ -199,6 +227,20 @@ class LassoModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = LassoModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -211,6 +253,19 @@ class LassoModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.LassoModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = LassoModel(weights, intercept) + return model class LassoWithSGD(object): @@ -246,6 +301,20 @@ class RidgeRegressionModel(LinearRegressionModelBase): True >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> lrm.save(sc, path) + >>> sameModel = RidgeRegressionModel.load(sc, path) + >>> abs(sameModel.predict(np.array([0.0])) - 0) < 0.5 + True + >>> abs(sameModel.predict(np.array([1.0])) - 1) < 0.5 + True + >>> abs(sameModel.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 + True + >>> try: + ... os.removedirs(path) + ... except: + ... pass >>> data = [ ... LabeledPoint(0.0, SparseVector(1, {0: 0.0})), ... LabeledPoint(1.0, SparseVector(1, {0: 1.0})), @@ -258,6 +327,19 @@ class RidgeRegressionModel(LinearRegressionModelBase): >>> abs(lrm.predict(SparseVector(1, {0: 1.0})) - 1) < 0.5 True """ + def save(self, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel( + _py2java(sc, self._coeff), self.intercept) + java_model.save(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = sc._jvm.org.apache.spark.mllib.regression.RidgeRegressionModel.load( + sc._jsc.sc(), path) + weights = _java2py(sc, java_model.weights()) + intercept = java_model.intercept() + model = RidgeRegressionModel(weights, intercept) + return model class RidgeRegressionWithSGD(object): diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index e877c720ac..c5c3468eb9 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -20,7 +20,6 @@ import warnings from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector -from pyspark.mllib.regression import LabeledPoint class MLUtils(object): @@ -50,6 +49,7 @@ class MLUtils(object): @staticmethod def _convert_labeled_point_to_libsvm(p): """Converts a LabeledPoint to a string in LIBSVM format.""" + from pyspark.mllib.regression import LabeledPoint assert isinstance(p, LabeledPoint) items = [str(p.label)] v = _convert_to_vector(p.features) @@ -92,6 +92,7 @@ class MLUtils(object): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> tempFile = NamedTemporaryFile(delete=True) >>> tempFile.write("+1 1:1.0 3:2.0 5:3.0\\n-1\\n-1 2:4.0 4:5.0 6:6.0") >>> tempFile.flush() @@ -110,6 +111,7 @@ class MLUtils(object): >>> print examples[2] (-1.0,(6,[1,3,5],[4.0,5.0,6.0])) """ + from pyspark.mllib.regression import LabeledPoint if multiclass is not None: warnings.warn("deprecated", DeprecationWarning) @@ -130,6 +132,7 @@ class MLUtils(object): >>> from tempfile import NamedTemporaryFile >>> from fileinput import input + >>> from pyspark.mllib.regression import LabeledPoint >>> from glob import glob >>> from pyspark.mllib.util import MLUtils >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, 1.23), (2, 4.56)])), \ @@ -156,6 +159,7 @@ class MLUtils(object): >>> from tempfile import NamedTemporaryFile >>> from pyspark.mllib.util import MLUtils + >>> from pyspark.mllib.regression import LabeledPoint >>> examples = [LabeledPoint(1.1, Vectors.sparse(3, [(0, -1.23), (2, 4.56e-7)])), \ LabeledPoint(0.0, Vectors.dense([1.01, 2.02, 3.03]))] >>> tempFile = NamedTemporaryFile(delete=True) |