aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r--python/pyspark/ml/pipeline.py150
1 files changed, 75 insertions, 75 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a1658b0a02..2b5504bc29 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -24,72 +24,31 @@ from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
-def _stages_java2py(java_stages):
- """
- Transforms the parameter Python stages from a list of Java stages.
- :param java_stages: An array of Java stages.
- :return: An array of Python stages.
- """
-
- return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
-
-
-def _stages_py2java(py_stages, cls):
- """
- Transforms the parameter of Python stages to a Java array of Java stages.
- :param py_stages: An array of Python stages.
- :return: A Java array of Java Stages.
- """
-
- for stage in py_stages:
- assert(isinstance(stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- gateway = SparkContext._gateway
- java_stages = gateway.new_array(cls, len(py_stages))
- for idx, stage in enumerate(py_stages):
- java_stages[idx] = stage._transfer_stage_to_java()
- return java_stages
-
-
@inherit_doc
-class PipelineMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineMLWriter(JavaMLWriter):
"""
Private Pipeline utility class that can save ML instances through their Scala implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
- self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
- self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
- self._jwrite = self._java_obj.write()
+ We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java.
+ """
@inherit_doc
class PipelineMLReader(JavaMLReader):
"""
Private utility class that can load Pipeline instances through their Scala implementation.
- """
- def load(self, path):
- """Load the Pipeline instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
-
- java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._resetUid(java_obj.uid())
- instance.setStages(_stages_java2py(java_obj.getStages()))
-
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java.
+ """
@inherit_doc
-class Pipeline(Estimator):
+class Pipeline(Estimator, MLReadable, MLWritable):
"""
A simple pipeline, which acts as an estimator. A Pipeline consists
of a sequence of stages, each of which is either an
@@ -206,49 +165,65 @@ class Pipeline(Estimator):
@classmethod
@since("2.0.0")
def read(cls):
- """Returns an JavaMLReader instance for this class."""
+ """Returns an MLReader instance for this class."""
return PipelineMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java Pipeline, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Create a new instance of this stage.
+ py_stage = cls()
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stage.setStages(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java Pipeline. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+ java_stages = gateway.new_array(cls, len(self.getStages()))
+ for idx, stage in enumerate(self.getStages()):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj.setStages(java_stages)
+
+ return _java_obj
@inherit_doc
-class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineModelMLWriter(JavaMLWriter):
"""
Private PipelineModel utility class that can save ML instances through their Scala
implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.Transformer
- self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
- instance.uid,
- _stages_py2java(instance.stages, cls))
- self._jwrite = self._java_obj.write()
+ We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements
+ _to_java.
+ """
@inherit_doc
class PipelineModelMLReader(JavaMLReader):
"""
Private utility class that can load PipelineModel instances through their Scala implementation.
- """
- def load(self, path):
- """Load the PipelineModel instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
- java_obj = self._jread.load(path)
- instance = self._clazz(_stages_java2py(java_obj.stages()))
- instance._resetUid(java_obj.uid())
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements
+ _from_java.
+ """
@inherit_doc
-class PipelineModel(Model):
+class PipelineModel(Model, MLReadable, MLWritable):
"""
Represents a compiled pipeline with transformers and fitted models.
@@ -294,7 +269,32 @@ class PipelineModel(Model):
return PipelineModelMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java PipelineModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ # Create a new instance of this stage.
+ py_stage = cls(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java PipelineModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+ java_stages = gateway.new_array(cls, len(self.stages))
+ for idx, stage in enumerate(self.stages):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj =\
+ JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+
+ return _java_obj