aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py37
1 files changed, 28 insertions, 9 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 37dcb23b67..35b0eba926 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -95,12 +95,26 @@ class JavaWrapper(Params):
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _transfer_stage_to_java(self):
+ def _to_java(self):
+ """
+ Transfer this instance's Params to the wrapped Java object, and return the Java object.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method.
+
+ :return: Java object equivalent to this instance.
+ """
self._transfer_params_to_java()
return self._java_obj
@staticmethod
- def _transfer_stage_from_java(java_stage):
+ def _from_java(java_stage):
+ """
+ Given a Java object, create and return a Python wrapper of it.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method as a classmethod.
+ """
def __get_class(clazz):
"""
Loads Python class from its name.
@@ -113,13 +127,18 @@ class JavaWrapper(Params):
return m
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
- py_stage = __get_class(stage_name)()
- assert(isinstance(py_stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- # Load information from java_stage to the instance.
- py_stage._java_obj = java_stage
- py_stage._resetUid(java_stage.uid())
- py_stage._transfer_params_from_java()
+ py_type = __get_class(stage_name)
+ if issubclass(py_type, JavaWrapper):
+ # Load information from java_stage to the instance.
+ py_stage = py_type()
+ py_stage._java_obj = java_stage
+ py_stage._resetUid(java_stage.uid())
+ py_stage._transfer_params_from_java()
+ elif hasattr(py_type, "_from_java"):
+ py_stage = py_type._from_java(java_stage)
+ else:
+ raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
+ % stage_name)
return py_stage