aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/pipeline.py
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-03-22 12:11:23 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-22 12:11:37 -0700
commit7e3423b9c03c9812d404134c3d204c4cfea87721 (patch)
treeb922610e318774c1db7da6549ee0932b21fe3090 /python/pyspark/ml/pipeline.py
parent297c20226d3330309c9165d789749458f8f4ab8e (diff)
downloadspark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.gz
spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.bz2
spark-7e3423b9c03c9812d404134c3d204c4cfea87721.zip
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations. Also: * Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader. * Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11866 from jkbradley/nested-pipeline-io. Closes #11835
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r--python/pyspark/ml/pipeline.py150
1 files changed, 75 insertions, 75 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a1658b0a02..2b5504bc29 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -24,72 +24,31 @@ from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
-def _stages_java2py(java_stages):
- """
- Transforms the parameter Python stages from a list of Java stages.
- :param java_stages: An array of Java stages.
- :return: An array of Python stages.
- """
-
- return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
-
-
-def _stages_py2java(py_stages, cls):
- """
- Transforms the parameter of Python stages to a Java array of Java stages.
- :param py_stages: An array of Python stages.
- :return: A Java array of Java Stages.
- """
-
- for stage in py_stages:
- assert(isinstance(stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- gateway = SparkContext._gateway
- java_stages = gateway.new_array(cls, len(py_stages))
- for idx, stage in enumerate(py_stages):
- java_stages[idx] = stage._transfer_stage_to_java()
- return java_stages
-
-
@inherit_doc
-class PipelineMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineMLWriter(JavaMLWriter):
"""
Private Pipeline utility class that can save ML instances through their Scala implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
- self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
- self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
- self._jwrite = self._java_obj.write()
+ We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java.
+ """
@inherit_doc
class PipelineMLReader(JavaMLReader):
"""
Private utility class that can load Pipeline instances through their Scala implementation.
- """
- def load(self, path):
- """Load the Pipeline instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
-
- java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._resetUid(java_obj.uid())
- instance.setStages(_stages_java2py(java_obj.getStages()))
-
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java.
+ """
@inherit_doc
-class Pipeline(Estimator):
+class Pipeline(Estimator, MLReadable, MLWritable):
"""
A simple pipeline, which acts as an estimator. A Pipeline consists
of a sequence of stages, each of which is either an
@@ -206,49 +165,65 @@ class Pipeline(Estimator):
@classmethod
@since("2.0.0")
def read(cls):
- """Returns an JavaMLReader instance for this class."""
+ """Returns an MLReader instance for this class."""
return PipelineMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java Pipeline, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Create a new instance of this stage.
+ py_stage = cls()
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stage.setStages(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java Pipeline. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+ java_stages = gateway.new_array(cls, len(self.getStages()))
+ for idx, stage in enumerate(self.getStages()):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj.setStages(java_stages)
+
+ return _java_obj
@inherit_doc
-class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineModelMLWriter(JavaMLWriter):
"""
Private PipelineModel utility class that can save ML instances through their Scala
implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.Transformer
- self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
- instance.uid,
- _stages_py2java(instance.stages, cls))
- self._jwrite = self._java_obj.write()
+ We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements
+ _to_java.
+ """
@inherit_doc
class PipelineModelMLReader(JavaMLReader):
"""
Private utility class that can load PipelineModel instances through their Scala implementation.
- """
- def load(self, path):
- """Load the PipelineModel instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
- java_obj = self._jread.load(path)
- instance = self._clazz(_stages_java2py(java_obj.stages()))
- instance._resetUid(java_obj.uid())
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements
+ _from_java.
+ """
@inherit_doc
-class PipelineModel(Model):
+class PipelineModel(Model, MLReadable, MLWritable):
"""
Represents a compiled pipeline with transformers and fitted models.
@@ -294,7 +269,32 @@ class PipelineModel(Model):
return PipelineModelMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java PipelineModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ # Create a new instance of this stage.
+ py_stage = cls(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java PipelineModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+ java_stages = gateway.new_array(cls, len(self.stages))
+ for idx, stage in enumerate(self.stages):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj =\
+ JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+
+ return _java_obj