From 7b1457839bdac124a07fd6292f6263f0ded48880 Mon Sep 17 00:00:00 2001
From: Yanbo Liang
Date: Tue, 5 May 2015 22:57:13 -0700
Subject: [SPARK-6267] [MLLIB] Python API for IsotonicRegression

https://issues.apache.org/jira/browse/SPARK-6267

Author: Yanbo Liang
Author: Xiangrui Meng

Closes #5890 from yanboliang/spark-6267 and squashes the following commits:

f20541d [Yanbo Liang] Merge pull request #3 from mengxr/SPARK-6267
7f202f9 [Xiangrui Meng] use Vector to have the best Python 2&3 compatibility
4bccfee [Yanbo Liang] fix doctest
ec09412 [Yanbo Liang] fix typos
8214bbb [Yanbo Liang] fix code style
5c8ebe5 [Yanbo Liang] Python API for IsotonicRegression
---
 python/pyspark/mllib/regression.py | 73 ++++++++++++++++++++++++++++++++++++--
 1 file changed, 71 insertions(+), 2 deletions(-)

(limited to 'python/pyspark/mllib/regression.py')

diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 4bc6351bdf..41bde2ce3e 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,14 +18,16 @@
 import numpy as np
 from numpy import array

+from pyspark import RDD
 from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py, inherit_doc
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import SparseVector, Vectors, _convert_to_vector
 from pyspark.mllib.util import Saveable, Loader

 __all__ = ['LabeledPoint', 'LinearModel',
            'LinearRegressionModel', 'LinearRegressionWithSGD',
            'RidgeRegressionModel', 'RidgeRegressionWithSGD',
-           'LassoModel', 'LassoWithSGD']
+           'LassoModel', 'LassoWithSGD', 'IsotonicRegressionModel',
+           'IsotonicRegression']


 class LabeledPoint(object):
@@ -396,6 +398,73 @@ class RidgeRegressionWithSGD(object):
         return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)


+class IsotonicRegressionModel(Saveable, Loader):
+
+    """Regression model for isotonic regression.
+
+    >>> data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
+    >>> irm = IsotonicRegression.train(sc.parallelize(data))
+    >>> irm.predict(3)
+    2.0
+    >>> irm.predict(5)
+    16.5
+    >>> irm.predict(sc.parallelize([3, 5])).collect()
+    [2.0, 16.5]
+    >>> import os, tempfile
+    >>> path = tempfile.mkdtemp()
+    >>> irm.save(sc, path)
+    >>> sameModel = IsotonicRegressionModel.load(sc, path)
+    >>> sameModel.predict(3)
+    2.0
+    >>> sameModel.predict(5)
+    16.5
+    >>> try:
+    ...     os.removedirs(path)
+    ... except OSError:
+    ...     pass
+    """
+
+    def __init__(self, boundaries, predictions, isotonic):
+        self.boundaries = boundaries
+        self.predictions = predictions
+        self.isotonic = isotonic
+
+    def predict(self, x):
+        if isinstance(x, RDD):
+            return x.map(lambda v: self.predict(v))
+        return np.interp(x, self.boundaries, self.predictions)
+
+    def save(self, sc, path):
+        java_boundaries = _py2java(sc, self.boundaries.tolist())
+        java_predictions = _py2java(sc, self.predictions.tolist())
+        java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel(
+            java_boundaries, java_predictions, self.isotonic)
+        java_model.save(sc._jsc.sc(), path)
+
+    @classmethod
+    def load(cls, sc, path):
+        java_model = sc._jvm.org.apache.spark.mllib.regression.IsotonicRegressionModel.load(
+            sc._jsc.sc(), path)
+        py_boundaries = _java2py(sc, java_model.boundaryVector()).toArray()
+        py_predictions = _java2py(sc, java_model.predictionVector()).toArray()
+        return IsotonicRegressionModel(py_boundaries, py_predictions, java_model.isotonic)
+
+
+class IsotonicRegression(object):
+    """
+    Run the IsotonicRegression algorithm to obtain an isotonic regression model.
+
+    :param data: RDD of (label, feature, weight) tuples.
+    :param isotonic: Whether this is isotonic (increasing) or antitonic (decreasing).
+    """
+    @classmethod
+    def train(cls, data, isotonic=True):
+        """Train an isotonic regression model on the given data."""
+        boundaries, predictions = callMLlibFunc("trainIsotonicRegressionModel",
+                                                data.map(_convert_to_vector), bool(isotonic))
+        return IsotonicRegressionModel(boundaries.toArray(), predictions.toArray(), isotonic)
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
-- 
cgit v1.2.3
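
A minimal usage sketch of the API this patch adds (not part of the commit; it assumes a locally created SparkContext with an illustrative app name, and reuses the toy data from the doctest above):

    from pyspark import SparkContext
    from pyspark.mllib.regression import IsotonicRegression, IsotonicRegressionModel

    sc = SparkContext("local", "isotonic-regression-example")

    # (label, feature, weight) tuples, as expected by IsotonicRegression.train.
    data = [(1, 0, 1), (2, 1, 1), (3, 2, 1), (1, 3, 1), (6, 4, 1), (17, 5, 1), (16, 6, 1)]
    model = IsotonicRegression.train(sc.parallelize(data), isotonic=True)

    # Single-point prediction interpolates linearly between the fitted boundaries.
    print(model.predict(3))                                   # 2.0
    print(model.predict(5))                                   # 16.5

    # Prediction also works element-wise over an RDD of feature values.
    print(model.predict(sc.parallelize([3, 5])).collect())    # [2.0, 16.5]

    # Round-trip the model through save/load, mirroring the doctest.
    import shutil, tempfile
    path = tempfile.mkdtemp()
    model.save(sc, path)
    same_model = IsotonicRegressionModel.load(sc, path)
    print(same_model.predict(3))                              # 2.0
    shutil.rmtree(path)

    sc.stop()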