aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-03-22 12:11:23 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-22 12:11:37 -0700
commit7e3423b9c03c9812d404134c3d204c4cfea87721 (patch)
treeb922610e318774c1db7da6549ee0932b21fe3090 /python/pyspark/ml/wrapper.py
parent297c20226d3330309c9165d789749458f8f4ab8e (diff)
downloadspark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.gz
spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.bz2
spark-7e3423b9c03c9812d404134c3d204c4cfea87721.zip
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations.

Also:
* Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader.
* Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java.

Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11866 from jkbradley/nested-pipeline-io. Closes #11835
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py37
1 file changed, 28 insertions, 9 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 37dcb23b67..35b0eba926 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -95,12 +95,26 @@ class JavaWrapper(Params):
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _transfer_stage_to_java(self):
+ def _to_java(self):
+ """
+ Transfer this instance's Params to the wrapped Java object, and return the Java object.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method.
+
+ :return: Java object equivalent to this instance.
+ """
self._transfer_params_to_java()
return self._java_obj
@staticmethod
- def _transfer_stage_from_java(java_stage):
+ def _from_java(java_stage):
+ """
+ Given a Java object, create and return a Python wrapper of it.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method as a classmethod.
+ """
def __get_class(clazz):
"""
Loads Python class from its name.
@@ -113,13 +127,18 @@ class JavaWrapper(Params):
return m
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
- py_stage = __get_class(stage_name)()
- assert(isinstance(py_stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- # Load information from java_stage to the instance.
- py_stage._java_obj = java_stage
- py_stage._resetUid(java_stage.uid())
- py_stage._transfer_params_from_java()
+ py_type = __get_class(stage_name)
+ if issubclass(py_type, JavaWrapper):
+ # Load information from java_stage to the instance.
+ py_stage = py_type()
+ py_stage._java_obj = java_stage
+ py_stage._resetUid(java_stage.uid())
+ py_stage._transfer_params_from_java()
+ elif hasattr(py_type, "_from_java"):
+ py_stage = py_type._from_java(java_stage)
+ else:
+ raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
+ % stage_name)
return py_stage