author     Yanbo Liang <ybliang8@gmail.com>         2016-01-29 09:22:24 -0800
committer  Joseph K. Bradley <joseph@databricks.com> 2016-01-29 09:22:24 -0800
commit     e51b6eaa9e9c007e194d858195291b2b9fb27322 (patch)
tree       b6af90c439154fe7514fd32e47a56a693ffd745a /python
parent     55561e7693dd2a5bf3c7f8026c725421801fd0ec (diff)
[SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example
* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark.
* Make ```LinearRegression``` support ```save/load``` as an example.

After this is merged, the work for the other transformers/estimators will be easy; we can then list the tasks and distribute them to the community.

cc mengxr jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #10469 from yanboliang/spark-11939.
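A minimal sketch of the round trip this patch enables, assuming a running SparkContext/SQLContext and a fitted `model` (parameter values and paths are illustrative; see the doctest added to regression.py below):

    import tempfile
    from pyspark.ml.regression import LinearRegression, LinearRegressionModel

    path = tempfile.mkdtemp()

    lr = LinearRegression(maxIter=5)            # illustrative params
    lr.save(path + "/lr")                       # estimator -> disk (JavaMLWriter under the hood)
    lr2 = LinearRegression.load(path + "/lr")   # disk -> estimator (JavaMLReader under the hood)

    model.save(path + "/lr_model")              # `model` assumed to come from lr.fit(df)
    model2 = LinearRegressionModel.load(path + "/lr_model")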
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/param/__init__.py  |  24
-rw-r--r--  python/pyspark/ml/regression.py      |  30
-rw-r--r--  python/pyspark/ml/tests.py           |  36
-rw-r--r--  python/pyspark/ml/util.py            | 142
-rw-r--r--  python/pyspark/ml/wrapper.py         |  33
5 files changed, 236 insertions, 29 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 3da36d32c5..ea86d6aeb8 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -314,3 +314,27 @@ class Params(Identifiable):
if p in paramMap and to.hasParam(p.name):
to._set(**{p.name: paramMap[p]})
return to
+
+ def _resetUid(self, newUid):
+ """
+ Changes the uid of this instance. This updates both
+ the stored uid and the parent uid of params and param maps.
+ This is used by persistence (loading).
+ :param newUid: new uid to use
+ :return: same instance, but with the uid and Param.parent values
+ updated, including within param maps
+ """
+ self.uid = newUid
+ newDefaultParamMap = dict()
+ newParamMap = dict()
+ for param in self.params:
+ newParam = copy.copy(param)
+ newParam.parent = newUid
+ if param in self._defaultParamMap:
+ newDefaultParamMap[newParam] = self._defaultParamMap[param]
+ if param in self._paramMap:
+ newParamMap[newParam] = self._paramMap[param]
+ param.parent = newUid
+ self._defaultParamMap = newDefaultParamMap
+ self._paramMap = newParamMap
+ return self
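The effect, illustrated with a hypothetical loaded estimator (assuming `lr_path` is a directory previously written with `lr.save(lr_path)`; this mirrors the check in the new PersistenceTest further down):

    lr2 = LinearRegression.load(lr_path)    # load() calls instance._resetUid(java_obj.uid())

    # The instance uid and every Param's parent now agree, and the default /
    # user-supplied param maps are keyed by the re-parented Param copies:
    assert lr2.uid == lr2.maxIter.parent
    assert all(p.parent == lr2.uid for p in lr2.params)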
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 74a2248ed0..20dc6c2db9 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -18,9 +18,9 @@
import warnings
from pyspark import since
-from pyspark.ml.util import keyword_only
-from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
+from pyspark.ml.util import *
+from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.mllib.common import inherit_doc
@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol):
+ HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
"""
Linear regression.
@@ -68,6 +68,25 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> lr_path = path + "/lr"
+ >>> lr.save(lr_path)
+ >>> lr2 = LinearRegression.load(lr_path)
+ >>> lr2.getMaxIter()
+ 5
+ >>> model_path = path + "/lr_model"
+ >>> model.save(model_path)
+ >>> model2 = LinearRegressionModel.load(model_path)
+ >>> model.coefficients[0] == model2.coefficients[0]
+ True
+ >>> model.intercept == model2.intercept
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
.. versionadded:: 1.4.0
"""
@@ -106,7 +125,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel):
+class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
"""
Model fitted by LinearRegression.
@@ -821,9 +840,10 @@ class AFTSurvivalRegressionModel(JavaModel):
if __name__ == "__main__":
import doctest
+ import pyspark.ml.regression
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
- globs = globals().copy()
+ globs = pyspark.ml.regression.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.regression tests")
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index c45a159c46..54806ee336 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -34,18 +34,22 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-from pyspark.sql import DataFrame, SQLContext, Row
-from pyspark.sql.functions import rand
+from shutil import rmtree
+import tempfile
+
+from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
-from pyspark.ml.util import keyword_only
-from pyspark.ml import Estimator, Model, Pipeline, Transformer
-from pyspark.ml.feature import *
+from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, CrossValidatorModel
+from pyspark.ml.util import keyword_only
from pyspark.mllib.linalg import DenseVector
+from pyspark.sql import DataFrame, SQLContext, Row
+from pyspark.sql.functions import rand
+from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
class MockDataset(DataFrame):
@@ -405,6 +409,26 @@ class CrossValidatorTests(PySparkTestCase):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+class PersistenceTest(PySparkTestCase):
+
+ def test_linear_regression(self):
+ lr = LinearRegression(maxIter=1)
+ path = tempfile.mkdtemp()
+ lr_path = path + "/lr"
+ lr.save(lr_path)
+ lr2 = LinearRegression.load(lr_path)
+ self.assertEqual(lr2.uid, lr2.maxIter.parent,
+ "Loaded LinearRegression instance uid (%s) did not match Param's uid (%s)"
+ % (lr2.uid, lr2.maxIter.parent))
+ self.assertEqual(lr._defaultParamMap[lr.maxIter], lr2._defaultParamMap[lr2.maxIter],
+ "Loaded LinearRegression instance default params did not match " +
+ "original defaults")
+ try:
+ rmtree(path)
+ except OSError:
+ pass
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index cee9d67b05..d7a813f56c 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -15,8 +15,27 @@
# limitations under the License.
#
-from functools import wraps
+import sys
import uuid
+from functools import wraps
+
+if sys.version > '3':
+ basestring = str
+
+from pyspark import SparkContext, since
+from pyspark.mllib.common import inherit_doc
+
+
+def _jvm():
+ """
+ Returns the JVM view associated with SparkContext. Must be called
+ after SparkContext is initialized.
+ """
+ jvm = SparkContext._jvm
+ if jvm:
+ return jvm
+ else:
+ raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
def keyword_only(func):
@@ -52,3 +71,124 @@ class Identifiable(object):
concatenates the class name, "_", and 12 random hex chars.
"""
return cls.__name__ + "_" + uuid.uuid4().hex[12:]
+
+
+@inherit_doc
+class JavaMLWriter(object):
+ """
+ .. note:: Experimental
+
+ Utility class that can save ML instances through their Scala implementation.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def __init__(self, instance):
+ instance._transfer_params_to_java()
+ self._jwrite = instance._java_obj.write()
+
+ def save(self, path):
+ """Save the ML instance to the input path."""
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
+ self._jwrite.save(path)
+
+ def overwrite(self):
+ """Overwrites if the output path already exists."""
+ self._jwrite.overwrite()
+ return self
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for saving."""
+ self._jwrite.context(sqlContext._ssql_ctx)
+ return self
+
+
+@inherit_doc
+class MLWritable(object):
+ """
+ .. note:: Experimental
+
+ Mixin for ML instances that provide JavaMLWriter.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def write(self):
+ """Returns an JavaMLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+
+@inherit_doc
+class JavaMLReader(object):
+ """
+ .. note:: Experimental
+
+ Utility class that can load ML instances through their Scala implementation.
+
+ .. versionadded:: 2.0.0
+ """
+
+ def __init__(self, clazz):
+ self._clazz = clazz
+ self._jread = self._load_java_obj(clazz).read()
+
+ def load(self, path):
+ """Load the ML instance from the input path."""
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
+ java_obj = self._jread.load(path)
+ instance = self._clazz()
+ instance._java_obj = java_obj
+ instance._resetUid(java_obj.uid())
+ instance._transfer_params_from_java()
+ return instance
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for loading."""
+ self._jread.context(sqlContext._ssql_ctx)
+ return self
+
+ @classmethod
+ def _java_loader_class(cls, clazz):
+ """
+ Returns the full class name of the Java ML instance. The default
+ implementation replaces "pyspark" with "org.apache.spark" in
+ the Python full class name.
+ """
+ java_package = clazz.__module__.replace("pyspark", "org.apache.spark")
+ return ".".join([java_package, clazz.__name__])
+
+ @classmethod
+ def _load_java_obj(cls, clazz):
+ """Load the peer Java object of the ML instance."""
+ java_class = cls._java_loader_class(clazz)
+ java_obj = _jvm()
+ for name in java_class.split("."):
+ java_obj = getattr(java_obj, name)
+ return java_obj
+
+
+@inherit_doc
+class MLReadable(object):
+ """
+ .. note:: Experimental
+
+ Mixin for instances that provide JavaMLReader.
+
+ .. versionadded:: 2.0.0
+ """
+
+ @classmethod
+ def read(cls):
+ """Returns an JavaMLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def load(cls, path):
+ """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
+ return cls.read().load(path)
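Taken together, the writer and reader are fluent; a sketch of the explicit form, assuming an existing `sqlContext`, an estimator `lr`, and an illustrative output path:

    # Explicit save with options, instead of the save() shortcut on MLWritable:
    lr.write().overwrite().context(sqlContext).save("/tmp/lr")

    # Class-level load through MLReadable / JavaMLReader:
    lr2 = LinearRegression.read().context(sqlContext).load("/tmp/lr")

    # _java_loader_class resolves the Scala peer purely by package renaming, e.g.
    #   pyspark.ml.regression.LinearRegression -> org.apache.spark.ml.regression.LinearRegression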
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index dd1d4b076e..d4d48eb215 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,21 +21,10 @@ from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Model
+from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-def _jvm():
- """
- Returns the JVM view associated with SparkContext. Must be called
- after SparkContext is initialized.
- """
- jvm = SparkContext._jvm
- if jvm:
- return jvm
- else:
- raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
@inherit_doc
class JavaWrapper(Params):
"""
@@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer):
__metaclass__ = ABCMeta
- def __init__(self, java_model):
+ def __init__(self, java_model=None):
"""
Initialize this instance with a Java model object.
Subclasses should call this constructor, initialize params,
and then call _transfer_params_from_java.
+
+ This instance can be instantiated without specifying java_model;
+ in that case it is assigned later, a scenario used only by
+ :py:class:`JavaMLReader` to load models. This is a bit of a
+ hack, but it is the easiest approach: a proper fix would require
+ MLReader (in pyspark.ml.util) to depend on these wrappers, yet
+ these wrappers depend on pyspark.ml.util (both directly and via
+ other ML classes).
"""
super(JavaModel, self).__init__()
- self._java_obj = java_model
- self.uid = java_model.uid()
+ if java_model is not None:
+ self._java_obj = java_model
+ self.uid = java_model.uid()
def copy(self, extra=None):
"""
@@ -182,8 +180,9 @@ class JavaModel(Model, JavaTransformer):
if extra is None:
extra = dict()
that = super(JavaModel, self).copy(extra)
- that._java_obj = self._java_obj.copy(self._empty_java_param_map())
- that._transfer_params_to_java()
+ if self._java_obj is not None:
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
return that
def _call_java(self, name, *args):
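For orientation, the load path that motivates the java_model=None default above roughly looks like this (a paraphrase of JavaMLReader.load from pyspark/ml/util.py with a hypothetical helper name, not a verbatim copy):

    def _load_wrapped(reader, model_cls, path):
        java_obj = reader._jread.load(path)     # Scala-side MLReader returns the peer Java object
        instance = model_cls()                  # wrapper constructed with java_model=None (the hack above)
        instance._java_obj = java_obj           # attach the Java peer after construction
        instance._resetUid(java_obj.uid())      # re-parent all Params to the loaded uid
        instance._transfer_params_from_java()   # copy Java param values back into the Python paramMap
        return instance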