aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
author: Xusen Yin <yinxusen@gmail.com> 2016-03-16 13:49:40 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-03-16 13:49:40 -0700
commit: ae6c677c8a03174787be99af6238a5e1fbe4e389 (patch)
tree: 75943410b6cfbe50c66ff199ab6164d24edeef84 /python/pyspark/ml/wrapper.py
parent: c4bd57602c0b14188d364bb475631bf473d25082 (diff)
downloadspark-ae6c677c8a03174787be99af6238a5e1fbe4e389.tar.gz
spark-ae6c677c8a03174787be99af6238a5e1fbe4e389.tar.bz2
spark-ae6c677c8a03174787be99af6238a5e1fbe4e389.zip
[SPARK-13038][PYSPARK] Add load/save to pipeline
## What changes were proposed in this pull request? JIRA issue: https://issues.apache.org/jira/browse/SPARK-13038 1. Add load/save to PySpark Pipeline and PipelineModel 2. Add `_transfer_stage_to_java()` and `_transfer_stage_from_java()` for `JavaWrapper`. ## How was this patch tested? Test with doctest. Author: Xusen Yin <yinxusen@gmail.com> Closes #11683 from yinxusen/SPARK-13038-only.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- python/pyspark/ml/wrapper.py 29
1 files changed, 28 insertions, 1 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index f8feaa1dfa..0f7b5e9b9e 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -19,8 +19,8 @@ from abc import ABCMeta, abstractmethod
from pyspark import SparkContext
from pyspark.sql import DataFrame
+from pyspark.ml import Estimator, Transformer, Model
from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Model
from pyspark.ml.util import _jvm
from pyspark.mllib.common import inherit_doc, _java2py, _py2java
@@ -90,6 +90,33 @@ class JavaWrapper(Params):
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
def _transfer_stage_to_java(self):
    """
    Return the Java object backing this stage after syncing its params.

    Calls ``self._transfer_params_to_java()`` first and only then reads
    ``self._java_obj``, so the returned object is the one held by the
    instance once the transfer has run.

    :return: the value of ``self._java_obj``.
    """
    self._transfer_params_to_java()
    java_stage = self._java_obj
    return java_stage
+
@staticmethod
def _transfer_stage_from_java(java_stage):
    """
    Build the Python stage that wraps the given Java stage.

    The Python class is looked up by taking the Java class name of
    ``java_stage`` and replacing ``org.apache.spark`` with ``pyspark``.
    A default instance of that class is created and then bound to
    ``java_stage``: its ``_java_obj`` is set, its uid is reset to
    ``java_stage.uid()`` and ``_transfer_params_from_java()`` is called.

    :param java_stage: Java stage object; must provide
                       ``getClass().getName()`` and ``uid()``.
    :return: the Python stage instance bound to ``java_stage``.
    :raises NotImplementedError: if the resolved Python class does not
                                 produce a :py:class:`JavaWrapper`.
    """
    def load_class(qualified_name):
        """
        Loads Python class from its name.
        """
        parts = qualified_name.split('.')
        module_name = ".".join(parts[:-1])
        # __import__ returns the top-level package for a dotted name, so
        # walk the remaining components (sub-modules, then the class).
        obj = __import__(module_name)
        for comp in parts[1:]:
            obj = getattr(obj, comp)
        return obj

    stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
    # Generate a default new instance from the stage_name class.
    py_stage = load_class(stage_name)()
    # An explicit check instead of `assert(cond, msg)`: asserting a
    # parenthesized 2-tuple is always true, so the original check never
    # fired (and a real assert would be stripped under `python -O`).
    if not isinstance(py_stage, JavaWrapper):
        raise NotImplementedError(
            "Python side implementation is not supported in the meta-PipelineStage currently.")
    # Load information from java_stage to the instance.
    py_stage._java_obj = java_stage
    py_stage._resetUid(java_stage.uid())
    py_stage._transfer_params_from_java()
    return py_stage
+
@inherit_doc
class JavaEstimator(Estimator, JavaWrapper):