author:    Xusen Yin <yinxusen@gmail.com>  2016-03-16 13:49:40 -0700
committer: Joseph K. Bradley <joseph@databricks.com>  2016-03-16 13:49:40 -0700
commit:    ae6c677c8a03174787be99af6238a5e1fbe4e389 (patch)
tree:      75943410b6cfbe50c66ff199ab6164d24edeef84  /python/pyspark/ml/pipeline.py
parent:    c4bd57602c0b14188d364bb475631bf473d25082 (diff)
[SPARK-13038][PYSPARK] Add load/save to pipeline
## What changes were proposed in this pull request?

JIRA issue: https://issues.apache.org/jira/browse/SPARK-13038

1. Add load/save to PySpark Pipeline and PipelineModel.
2. Add `_transfer_stage_to_java()` and `_transfer_stage_from_java()` for `JavaWrapper`.

## How was this patch tested?

Tested with doctests.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11683 from yinxusen/SPARK-13038-only.
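For context, a minimal usage sketch of the persistence API this patch adds. The feature stages, save path, and the surrounding SparkContext/SQLContext setup are illustrative assumptions, not part of the patch:

```python
from pyspark.ml import Pipeline
from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression

# Assemble a small pipeline: tokenize -> hash features -> logistic regression.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])

# Round-trip the (unfitted) pipeline through the new save()/load() methods.
pipeline.save("/tmp/spark-13038-pipeline")
loaded = Pipeline.load("/tmp/spark-13038-pipeline")
assert [s.uid for s in loaded.getStages()] == [s.uid for s in pipeline.getStages()]
```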
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r--  python/pyspark/ml/pipeline.py  208
1 file changed, 122 insertions, 86 deletions
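The core mechanic of the patch is moving stage lists across the Py4J boundary, as done by the `_stages_py2java` helper in the diff below. Here is a standalone sketch of that pattern; the name `to_java_array` is hypothetical, while `gateway.new_array` comes from Py4J and `_transfer_stage_to_java()` is added to `JavaWrapper` by this patch:

```python
from pyspark import SparkContext

def to_java_array(py_stages, java_cls):
    """Hypothetical helper mirroring the _stages_py2java pattern."""
    gateway = SparkContext._gateway
    # Allocate a typed Java array (e.g. of org.apache.spark.ml.PipelineStage)
    # through the Py4J gateway, then fill it with each stage's JVM counterpart.
    java_stages = gateway.new_array(java_cls, len(py_stages))
    for idx, stage in enumerate(py_stages):
        java_stages[idx] = stage._transfer_stage_to_java()
    return java_stages
```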
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 661074ca96..a1658b0a02 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -15,116 +15,77 @@
# limitations under the License.
#
-from abc import ABCMeta, abstractmethod
+import sys
+if sys.version > '3':
+ basestring = str
+
+from pyspark import SparkContext
from pyspark import since
+from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
-@inherit_doc
-class Estimator(Params):
+def _stages_java2py(java_stages):
"""
- Abstract class for estimators that fit models to data.
-
- .. versionadded:: 1.3.0
+ Transforms a list of Java stages into a list of Python stages.
+ :param java_stages: A list of Java stages.
+ :return: A list of Python stages.
"""
- __metaclass__ = ABCMeta
+ return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
- @abstractmethod
- def _fit(self, dataset):
- """
- Fits a model to the input dataset. This is called by the
- default implementation of fit.
- :param dataset: input dataset, which is an instance of
- :py:class:`pyspark.sql.DataFrame`
- :returns: fitted model
- """
- raise NotImplementedError()
+def _stages_py2java(py_stages, cls):
+ """
+ Transforms a list of Python stages into a Java array of Java stages.
+ :param py_stages: A list of Python stages.
+ :param cls: The Java class used as the component type of the array.
+ :return: A Java array of Java stages.
+ """
- @since("1.3.0")
- def fit(self, dataset, params=None):
- """
- Fits a model to the input dataset with optional parameters.
-
- :param dataset: input dataset, which is an instance of
- :py:class:`pyspark.sql.DataFrame`
- :param params: an optional param map that overrides embedded
- params. If a list/tuple of param maps is given,
- this calls fit on each param map and returns a
- list of models.
- :returns: fitted model(s)
- """
- if params is None:
- params = dict()
- if isinstance(params, (list, tuple)):
- return [self.fit(dataset, paramMap) for paramMap in params]
- elif isinstance(params, dict):
- if params:
- return self.copy(params)._fit(dataset)
- else:
- return self._fit(dataset)
- else:
- raise ValueError("Params must be either a param map or a list/tuple of param maps, "
- "but got %s." % type(params))
+ for stage in py_stages:
+ assert isinstance(stage, JavaWrapper), \
+ "Python side implementation is not supported in the meta-PipelineStage currently."
+ gateway = SparkContext._gateway
+ java_stages = gateway.new_array(cls, len(py_stages))
+ for idx, stage in enumerate(py_stages):
+ java_stages[idx] = stage._transfer_stage_to_java()
+ return java_stages
@inherit_doc
-class Transformer(Params):
+class PipelineMLWriter(JavaMLWriter, JavaWrapper):
"""
- Abstract class for transformers that transform one dataset into
- another.
-
- .. versionadded:: 1.3.0
+ Private Pipeline utility class that can save ML instances through their Scala implementation.
"""
- __metaclass__ = ABCMeta
-
- @abstractmethod
- def _transform(self, dataset):
- """
- Transforms the input dataset.
-
- :param dataset: input dataset, which is an instance of
- :py:class:`pyspark.sql.DataFrame`
- :returns: transformed dataset
- """
- raise NotImplementedError()
-
- @since("1.3.0")
- def transform(self, dataset, params=None):
- """
- Transforms the input dataset with optional parameters.
-
- :param dataset: input dataset, which is an instance of
- :py:class:`pyspark.sql.DataFrame`
- :param params: an optional param map that overrides embedded
- params.
- :returns: transformed dataset
- """
- if params is None:
- params = dict()
- if isinstance(params, dict):
- if params:
- return self.copy(params,)._transform(dataset)
- else:
- return self._transform(dataset)
- else:
- raise ValueError("Params must be either a param map but got %s." % type(params))
+ def __init__(self, instance):
+ cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
+ self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
+ self._jwrite = self._java_obj.write()
@inherit_doc
-class Model(Transformer):
+class PipelineMLReader(JavaMLReader):
"""
- Abstract class for models that are fitted by estimators.
-
- .. versionadded:: 1.4.0
+ Private utility class that can load Pipeline instances through their Scala implementation.
"""
- __metaclass__ = ABCMeta
+ def load(self, path):
+ """Load the Pipeline instance from the input path."""
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
+
+ java_obj = self._jread.load(path)
+ instance = self._clazz()
+ instance._resetUid(java_obj.uid())
+ instance.setStages(_stages_java2py(java_obj.getStages()))
+
+ return instance
@inherit_doc
@@ -232,6 +193,59 @@ class Pipeline(Estimator):
stages = [stage.copy(extra) for stage in that.getStages()]
return that.setStages(stages)
+ @since("2.0.0")
+ def write(self):
+ """Returns an JavaMLWriter instance for this ML instance."""
+ return PipelineMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an JavaMLReader instance for this class."""
+ return PipelineMLReader(cls)
+
+ @classmethod
+ @since("2.0.0")
+ def load(cls, path):
+ """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
+ return cls.read().load(path)
+
+
+@inherit_doc
+class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+ """
+ Private PipelineModel utility class that can save ML instances through their Scala
+ implementation.
+ """
+
+ def __init__(self, instance):
+ cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+ self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
+ instance.uid,
+ _stages_py2java(instance.stages, cls))
+ self._jwrite = self._java_obj.write()
+
+
+@inherit_doc
+class PipelineModelMLReader(JavaMLReader):
+ """
+ Private utility class that can load PipelineModel instances through their Scala implementation.
+ """
+
+ def load(self, path):
+ """Load the PipelineModel instance from the input path."""
+ if not isinstance(path, basestring):
+ raise TypeError("path should be a basestring, got type %s" % type(path))
+ java_obj = self._jread.load(path)
+ instance = self._clazz(_stages_java2py(java_obj.stages()))
+ instance._resetUid(java_obj.uid())
+ return instance
+
@inherit_doc
class PipelineModel(Model):
@@ -262,3 +276,25 @@ class PipelineModel(Model):
extra = dict()
stages = [stage.copy(extra) for stage in self.stages]
return PipelineModel(stages)
+
+ @since("2.0.0")
+ def write(self):
+ """Returns an JavaMLWriter instance for this ML instance."""
+ return PipelineModelMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an JavaMLReader instance for this class."""
+ return PipelineModelMLReader(cls)
+
+ @classmethod
+ @since("2.0.0")
+ def load(cls, path):
+ """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
+ return cls.read().load(path)
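For completeness, the PipelineModel side of the round trip, continuing the sketch above; `training_df` and `test_df` are placeholder DataFrames with "text" and "label" columns, and the paths are arbitrary:

```python
from pyspark.ml import PipelineModel

# Fit the pipeline and persist the resulting model through the new API.
model = pipeline.fit(training_df)
model.save("/tmp/spark-13038-pipeline-model")

# Reloading restores the uid and the fitted stages, so the model
# transforms new data exactly like the original.
same_model = PipelineModel.load("/tmp/spark-13038-pipeline-model")
predictions = same_model.transform(test_df)
```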