path: root/python/pyspark/ml/wrapper.py
author    Yanbo Liang <ybliang8@gmail.com>    2016-01-29 09:22:24 -0800
committer Joseph K. Bradley <joseph@databricks.com>    2016-01-29 09:22:24 -0800
commit    e51b6eaa9e9c007e194d858195291b2b9fb27322 (patch)
tree      b6af90c439154fe7514fd32e47a56a693ffd745a /python/pyspark/ml/wrapper.py
parent    55561e7693dd2a5bf3c7f8026c725421801fd0ec (diff)
[SPARK-13032][ML][PYSPARK] PySpark support for model export/import, taking LinearRegression as an example
* Implement ```MLWriter/MLWritable/MLReader/MLReadable``` for PySpark.
* Make ```LinearRegression``` support ```save/load``` as an example.

After this is merged, the work for other transformers/estimators will be easy, and we can then list and distribute the tasks to the community. cc mengxr jkbradley

Author: Yanbo Liang <ybliang8@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #10469 from yanboliang/spark-11939.
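For orientation, here is a minimal usage sketch of the save/load API this commit introduces. It assumes a `training` DataFrame of (label, features) rows and writable paths; both are hypothetical.

```python
# Sketch of the persistence API added by this commit (MLWritable.save /
# MLReadable.load); `training` and the /tmp paths are assumptions.
from pyspark.ml.regression import LinearRegression, LinearRegressionModel

lr = LinearRegression(maxIter=10, regParam=0.01)
model = lr.fit(training)

lr.save("/tmp/lr_estimator")    # persist the estimator and its params
model.save("/tmp/lr_model")     # persist the fitted model

loaded = LinearRegressionModel.load("/tmp/lr_model")
assert loaded.uid == model.uid  # the uid survives the round trip
```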
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--  python/pyspark/ml/wrapper.py  33
1 file changed, 16 insertions(+), 17 deletions(-)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index dd1d4b076e..d4d48eb215 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,21 +21,10 @@ from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer, Model
+from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
-def _jvm():
- """
- Returns the JVM view associated with SparkContext. Must be called
- after SparkContext is initialized.
- """
- jvm = SparkContext._jvm
- if jvm:
- return jvm
- else:
- raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?")
-
-
@inherit_doc
class JavaWrapper(Params):
"""
@@ -159,15 +148,24 @@ class JavaModel(Model, JavaTransformer):
__metaclass__ = ABCMeta
- def __init__(self, java_model):
+ def __init__(self, java_model=None):
"""
Initialize this instance with a Java model object.
Subclasses should call this constructor, initialize params,
and then call _transfer_params_from_java.
+
+ This instance can be instantiated without specifying java_model;
+ it will be assigned later. That scenario is only used by
+ :py:class:`JavaMLReader` to load models. This is a bit of a
+ hack, but it is the easiest approach: a proper fix would require
+ MLReader (in pyspark.ml.util) to depend on these wrappers, yet
+ these wrappers depend on pyspark.ml.util (both directly and via
+ other ML classes).
"""
super(JavaModel, self).__init__()
- self._java_obj = java_model
- self.uid = java_model.uid()
+ if java_model is not None:
+ self._java_obj = java_model
+ self.uid = java_model.uid()
def copy(self, extra=None):
"""
@@ -182,8 +180,9 @@ class JavaModel(Model, JavaTransformer):
if extra is None:
extra = dict()
that = super(JavaModel, self).copy(extra)
- that._java_obj = self._java_obj.copy(self._empty_java_param_map())
- that._transfer_params_to_java()
+ if self._java_obj is not None:
+ that._java_obj = self._java_obj.copy(self._empty_java_param_map())
+ that._transfer_params_to_java()
return that
def _call_java(self, name, *args):
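For context, here is a minimal sketch of the loading flow that the new ```__init__``` docstring alludes to, assuming hypothetical helper names (`_load_java_obj`, `py_cls`): the reader constructs the Python wrapper with no Java object, which the `java_model=None` default now permits, and attaches the Java-side model afterwards.

```python
# Hypothetical sketch of a JavaMLReader-style load; _load_java_obj (which
# resolves the Java companion class) and py_cls are assumptions.
def _load_model(py_cls, path):
    java_obj = _load_java_obj(py_cls).load(path)  # Java-side MLReader.load
    instance = py_cls()               # legal now that java_model defaults to None
    instance._java_obj = java_obj     # assign the Java model after construction
    instance.uid = java_obj.uid()
    instance._transfer_params_from_java()
    return instance
```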