From 7e3423b9c03c9812d404134c3d204c4cfea87721 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 22 Mar 2016 12:11:23 -0700 Subject: [SPARK-13951][ML][PYTHON] Nested Pipeline persistence Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations. Also: * Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader. * Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests. Author: Joseph K. Bradley Closes #11866 from jkbradley/nested-pipeline-io. Closes #11835 --- python/pyspark/ml/util.py | 89 ++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 73 insertions(+), 16 deletions(-) (limited to 'python/pyspark/ml/util.py') diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 42801c91bb..6703851262 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -74,18 +74,38 @@ class Identifiable(object): @inherit_doc -class JavaMLWriter(object): +class MLWriter(object): """ .. note:: Experimental - Utility class that can save ML instances through their Scala implementation. + Utility class that can save ML instances. .. versionadded:: 2.0.0 """ + def save(self, path): + """Save the ML instance to the input path.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def overwrite(self): + """Overwrites if the output path already exists.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + """ + def __init__(self, instance): - instance._transfer_params_to_java() - self._jwrite = instance._java_obj.write() + super(JavaMLWriter, self).__init__() + _java_obj = instance._to_java() + self._jwrite = _java_obj.write() def save(self, path): """Save the ML instance to the input path.""" @@ -109,14 +129,14 @@ class MLWritable(object): """ .. note:: Experimental - Mixin for ML instances that provide JavaMLWriter. + Mixin for ML instances that provide :py:class:`MLWriter`. .. versionadded:: 2.0.0 """ def write(self): """Returns an JavaMLWriter instance for this ML instance.""" - return JavaMLWriter(self) + raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" @@ -124,15 +144,41 @@ class MLWritable(object): @inherit_doc -class JavaMLReader(object): +class JavaMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + +@inherit_doc +class MLReader(object): """ .. note:: Experimental - Utility class that can load ML instances through their Scala implementation. + Utility class that can load ML instances. .. versionadded:: 2.0.0 """ + def load(self, path): + """Load the ML instance from the input path.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + """ + def __init__(self, clazz): self._clazz = clazz self._jread = self._load_java_obj(clazz).read() @@ -142,11 +188,10 @@ class JavaMLReader(object): if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) java_obj = self._jread.load(path) - instance = self._clazz() - instance._java_obj = java_obj - instance._resetUid(java_obj.uid()) - instance._transfer_params_from_java() - return instance + if not hasattr(self._clazz, "_from_java"): + raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" + % self._clazz) + return self._clazz._from_java(java_obj) def context(self, sqlContext): """Sets the SQL context to use for loading.""" @@ -164,7 +209,7 @@ class JavaMLReader(object): if clazz.__name__ in ("Pipeline", "PipelineModel"): # Remove the last package name "pipeline" for Pipeline and PipelineModel. java_package = ".".join(java_package.split(".")[0:-1]) - return ".".join([java_package, clazz.__name__]) + return java_package + "." + clazz.__name__ @classmethod def _load_java_obj(cls, clazz): @@ -181,7 +226,7 @@ class MLReadable(object): """ .. note:: Experimental - Mixin for instances that provide JavaMLReader. + Mixin for instances that provide :py:class:`MLReader`. .. versionadded:: 2.0.0 """ @@ -189,9 +234,21 @@ class MLReadable(object): @classmethod def read(cls): """Returns an JavaMLReader instance for this class.""" - return JavaMLReader(cls) + raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) @classmethod def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" return cls.read().load(path) + + +@inherit_doc +class JavaMLReadable(MLReadable): + """ + (Private) Mixin for instances that provide JavaMLReader. + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) -- cgit v1.2.3