diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-29 09:22:24 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-29 09:22:24 -0800 |
commit | e51b6eaa9e9c007e194d858195291b2b9fb27322 (patch) | |
tree | b6af90c439154fe7514fd32e47a56a693ffd745a /python/pyspark/ml/wrapper.py | |
parent | 55561e7693dd2a5bf3c7f8026c725421801fd0ec (diff) | |
download | spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.gz spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.tar.bz2 spark-e51b6eaa9e9c007e194d858195291b2b9fb27322.zip |
[SPARK-13032][ML][PYSPARK] PySpark support model export/import and take LinearRegression as example
* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark.
* Making ```LinearRegression``` to support ```save/load``` as example. After this merged, the work for other transformers/estimators will be easy, then we can list and distribute the tasks to the community.
cc mengxr jkbradley
Author: Yanbo Liang <ybliang8@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #10469 from yanboliang/spark-11939.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 33 |
1 files changed, 16 insertions, 17 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index dd1d4b076e..d4d48eb215 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -21,21 +21,10 @@ from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params from pyspark.ml.pipeline import Estimator, Transformer, Model +from pyspark.ml.util import _jvm from pyspark.mllib.common import inherit_doc, _java2py, _py2java -def _jvm(): - """ - Returns the JVM view associated with SparkContext. Must be called - after SparkContext is initialized. - """ - jvm = SparkContext._jvm - if jvm: - return jvm - else: - raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") - - @inherit_doc class JavaWrapper(Params): """ @@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer): __metaclass__ = ABCMeta - def __init__(self, java_model): + def __init__(self, java_model=None): """ Initialize this instance with a Java model object. Subclasses should call this constructor, initialize params, and then call _transformer_params_from_java. + + This instance can be instantiated without specifying java_model, + it will be assigned after that, but this scenario only used by + :py:class:`JavaMLReader` to load models. This is a bit of a + hack, but it is easiest since a proper fix would require + MLReader (in pyspark.ml.util) to depend on these wrappers, but + these wrappers depend on pyspark.ml.util (both directly and via + other ML classes). """ super(JavaModel, self).__init__() - self._java_obj = java_model - self.uid = java_model.uid() + if java_model is not None: + self._java_obj = java_model + self.uid = java_model.uid() def copy(self, extra=None): """ @@ -182,8 +180,9 @@ class JavaModel(Model, JavaTransformer): if extra is None: extra = dict() that = super(JavaModel, self).copy(extra) - that._java_obj = self._java_obj.copy(self._empty_java_param_map()) - that._transfer_params_to_java() + if self._java_obj is not None: + that._java_obj = self._java_obj.copy(self._empty_java_param_map()) + that._transfer_params_to_java() return that def _call_java(self, name, *args): |