diff options
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 37 |
1 files changed, 28 insertions, 9 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 37dcb23b67..35b0eba926 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -95,12 +95,26 @@ class JavaWrapper(Params):
         """
         return _jvm().org.apache.spark.ml.param.ParamMap()
 
-    def _transfer_stage_to_java(self):
+    def _to_java(self):
+        """
+        Transfer this instance's Params to the wrapped Java object, and return the Java object.
+        Used for ML persistence.
+
+        Meta-algorithms such as Pipeline should override this method.
+
+        :return: Java object equivalent to this instance.
+        """
         self._transfer_params_to_java()
         return self._java_obj
 
     @staticmethod
-    def _transfer_stage_from_java(java_stage):
+    def _from_java(java_stage):
+        """
+        Given a Java object, create and return a Python wrapper of it.
+        Used for ML persistence.
+
+        Meta-algorithms such as Pipeline should override this method as a classmethod.
+        """
         def __get_class(clazz):
             """
             Loads Python class from its name.
@@ -113,13 +127,18 @@ class JavaWrapper(Params):
             return m
         stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
         # Generate a default new instance from the stage_name class.
-        py_stage = __get_class(stage_name)()
-        assert(isinstance(py_stage, JavaWrapper),
-               "Python side implementation is not supported in the meta-PipelineStage currently.")
-        # Load information from java_stage to the instance.
-        py_stage._java_obj = java_stage
-        py_stage._resetUid(java_stage.uid())
-        py_stage._transfer_params_from_java()
+        py_type = __get_class(stage_name)
+        if issubclass(py_type, JavaWrapper):
+            # Load information from java_stage to the instance.
+            py_stage = py_type()
+            py_stage._java_obj = java_stage
+            py_stage._resetUid(java_stage.uid())
+            py_stage._transfer_params_from_java()
+        elif hasattr(py_type, "_from_java"):
+            py_stage = py_type._from_java(java_stage)
+        else:
+            raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
+                                      % stage_name)
         return py_stage
 
 