Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r-- | python/pyspark/ml/pipeline.py | 150
1 file changed, 75 insertions, 75 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a1658b0a02..2b5504bc29 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -24,72 +24,31 @@ from pyspark import SparkContext
 from pyspark import since
 from pyspark.ml import Estimator, Model, Transformer
 from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
 from pyspark.ml.wrapper import JavaWrapper
 from pyspark.mllib.common import inherit_doc
 
 
-def _stages_java2py(java_stages):
-    """
-    Transforms the parameter Python stages from a list of Java stages.
-    :param java_stages: An array of Java stages.
-    :return: An array of Python stages.
-    """
-
-    return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
-
-
-def _stages_py2java(py_stages, cls):
-    """
-    Transforms the parameter of Python stages to a Java array of Java stages.
-    :param py_stages: An array of Python stages.
-    :return: A Java array of Java Stages.
-    """
-
-    for stage in py_stages:
-        assert(isinstance(stage, JavaWrapper),
-               "Python side implementation is not supported in the meta-PipelineStage currently.")
-    gateway = SparkContext._gateway
-    java_stages = gateway.new_array(cls, len(py_stages))
-    for idx, stage in enumerate(py_stages):
-        java_stages[idx] = stage._transfer_stage_to_java()
-    return java_stages
-
-
 @inherit_doc
-class PipelineMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineMLWriter(JavaMLWriter):
     """
     Private Pipeline utility class that can save ML instances through their Scala implementation.
-    """
 
-    def __init__(self, instance):
-        cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
-        self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
-        self._jwrite = self._java_obj.write()
+    We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java.
+    """
 
 
 @inherit_doc
 class PipelineMLReader(JavaMLReader):
     """
     Private utility class that can load Pipeline instances through their Scala implementation.
-    """
 
-    def load(self, path):
-        """Load the Pipeline instance from the input path."""
-        if not isinstance(path, basestring):
-            raise TypeError("path should be a basestring, got type %s" % type(path))
-
-        java_obj = self._jread.load(path)
-        instance = self._clazz()
-        instance._resetUid(java_obj.uid())
-        instance.setStages(_stages_java2py(java_obj.getStages()))
-
-        return instance
+    We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java.
+    """
 
 
 @inherit_doc
-class Pipeline(Estimator):
+class Pipeline(Estimator, MLReadable, MLWritable):
     """
     A simple pipeline, which acts as an estimator. A Pipeline consists
     of a sequence of stages, each of which is either an
@@ -206,49 +165,65 @@ class Pipeline(Estimator):
     @classmethod
     @since("2.0.0")
     def read(cls):
-        """Returns an JavaMLReader instance for this class."""
+        """Returns an MLReader instance for this class."""
         return PipelineMLReader(cls)
 
     @classmethod
-    @since("2.0.0")
-    def load(cls, path):
-        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
-        return cls.read().load(path)
+    def _from_java(cls, java_stage):
+        """
+        Given a Java Pipeline, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+        # Create a new instance of this stage.
+        py_stage = cls()
+        # Load information from java_stage to the instance.
+        py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+        py_stage.setStages(py_stages)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java Pipeline. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+
+        gateway = SparkContext._gateway
+        cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+        java_stages = gateway.new_array(cls, len(self.getStages()))
+        for idx, stage in enumerate(self.getStages()):
+            java_stages[idx] = stage._to_java()
+
+        _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+        _java_obj.setStages(java_stages)
+
+        return _java_obj
 
 
 @inherit_doc
-class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineModelMLWriter(JavaMLWriter):
     """
     Private PipelineModel utility class that can save ML instances through their Scala
     implementation.
-    """
 
-    def __init__(self, instance):
-        cls = SparkContext._jvm.org.apache.spark.ml.Transformer
-        self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
-                                            instance.uid,
-                                            _stages_py2java(instance.stages, cls))
-        self._jwrite = self._java_obj.write()
+    We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements
+    _to_java.
+    """
 
 
 @inherit_doc
 class PipelineModelMLReader(JavaMLReader):
     """
     Private utility class that can load PipelineModel instances through their Scala implementation.
-    """
 
-    def load(self, path):
-        """Load the PipelineModel instance from the input path."""
-        if not isinstance(path, basestring):
-            raise TypeError("path should be a basestring, got type %s" % type(path))
-        java_obj = self._jread.load(path)
-        instance = self._clazz(_stages_java2py(java_obj.stages()))
-        instance._resetUid(java_obj.uid())
-        return instance
+    We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements
+    _from_java.
+    """
 
 
 @inherit_doc
-class PipelineModel(Model):
+class PipelineModel(Model, MLReadable, MLWritable):
     """
     Represents a compiled pipeline with transformers and fitted models.
 
@@ -294,7 +269,32 @@ class PipelineModel(Model):
         return PipelineModelMLReader(cls)
 
     @classmethod
-    @since("2.0.0")
-    def load(cls, path):
-        """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
-        return cls.read().load(path)
+    def _from_java(cls, java_stage):
+        """
+        Given a Java PipelineModel, create and return a Python wrapper of it.
+        Used for ML persistence.
+        """
+        # Load information from java_stage to the instance.
+        py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+        # Create a new instance of this stage.
+        py_stage = cls(py_stages)
+        py_stage._resetUid(java_stage.uid())
+        return py_stage
+
+    def _to_java(self):
+        """
+        Transfer this instance to a Java PipelineModel. Used for ML persistence.
+
+        :return: Java object equivalent to this instance.
+        """
+
+        gateway = SparkContext._gateway
+        cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+        java_stages = gateway.new_array(cls, len(self.stages))
+        for idx, stage in enumerate(self.stages):
+            java_stages[idx] = stage._to_java()
+
+        _java_obj =\
+            JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+
+        return _java_obj
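For context, the following is a minimal, hypothetical usage sketch (not part of the patch) of the persistence path this change enables: with Pipeline and PipelineModel mixing in MLWritable/MLReadable, save() and load() round-trip through PipelineMLWriter/PipelineMLReader and the _to_java/_from_java converters above. The example stages (Tokenizer, HashingTF, LogisticRegression), the SparkSession setup, and the paths are illustrative assumptions, not taken from the diff.

# Hypothetical sketch; assumes a Spark 2.x session, example stages, and placeholder paths.
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline, PipelineModel
from pyspark.ml.feature import Tokenizer, HashingTF
from pyspark.ml.classification import LogisticRegression

spark = SparkSession.builder.appName("pipeline-persistence-sketch").getOrCreate()
training = spark.createDataFrame(
    [(0, "a b c d e spark", 1.0), (1, "spark f g h", 1.0), (2, "hadoop mapreduce", 0.0)],
    ["id", "text", "label"])

tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashing_tf = HashingTF(inputCol="words", outputCol="features")
lr = LogisticRegression(maxIter=10)
pipeline = Pipeline(stages=[tokenizer, hashing_tf, lr])

# MLWritable.save() goes through PipelineMLWriter, which relies on Pipeline._to_java()
# to build the Java Pipeline before delegating to the Scala writer.
pipeline.save("/tmp/example-pipeline")
# MLReadable.load() goes through PipelineMLReader and Pipeline._from_java() to rebuild
# the Python wrapper from the loaded Java object.
loaded_pipeline = Pipeline.load("/tmp/example-pipeline")

# Fitted models persist the same way via PipelineModel._to_java()/_from_java().
model = pipeline.fit(training)
model.save("/tmp/example-pipeline-model")
loaded_model = PipelineModel.load("/tmp/example-pipeline-model")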