aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/util.py
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2016-03-22 12:11:23 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-22 12:11:37 -0700
commit7e3423b9c03c9812d404134c3d204c4cfea87721 (patch)
treeb922610e318774c1db7da6549ee0932b21fe3090 /python/pyspark/ml/util.py
parent297c20226d3330309c9165d789749458f8f4ab8e (diff)
downloadspark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.gz
spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.bz2
spark-7e3423b9c03c9812d404134c3d204c4cfea87721.zip
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter, JavaMLReader implementations. Also: * Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader. * Moves methods _stages_java2py, _stages_py2java into Pipeline, PipelineModel as _transfer_stage_from_java, _transfer_stage_to_java Added new unit test for nested Pipelines. Abstracted validity check into a helper method for the 2 unit tests. Author: Joseph K. Bradley <joseph@databricks.com> Closes #11866 from jkbradley/nested-pipeline-io. Closes #11835
Diffstat (limited to 'python/pyspark/ml/util.py')
-rw-r--r--python/pyspark/ml/util.py89
1 files changed, 73 insertions, 16 deletions
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 42801c91bb..6703851262 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -74,18 +74,38 @@ class Identifiable(object):
@inherit_doc
-class JavaMLWriter(object):
+class MLWriter(object):
"""
.. note:: Experimental
- Utility class that can save ML instances through their Scala implementation.
+ Utility class that can save ML instances.
.. versionadded:: 2.0.0
"""
+ def save(self, path):
+ """Save the ML instance to the input path."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def overwrite(self):
+ """Overwrites if the output path already exists."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for saving."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, instance):
- instance._transfer_params_to_java()
- self._jwrite = instance._java_obj.write()
+ super(JavaMLWriter, self).__init__()
+ _java_obj = instance._to_java()
+ self._jwrite = _java_obj.write()
def save(self, path):
"""Save the ML instance to the input path."""
@@ -109,14 +129,14 @@ class MLWritable(object):
"""
.. note:: Experimental
- Mixin for ML instances that provide JavaMLWriter.
+ Mixin for ML instances that provide :py:class:`MLWriter`.
.. versionadded:: 2.0.0
"""
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
@@ -124,15 +144,41 @@ class MLWritable(object):
@inherit_doc
-class JavaMLReader(object):
+class JavaMLWritable(MLWritable):
+ """
+ (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
+ """
+
+ def write(self):
+ """Returns an JavaMLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+
+@inherit_doc
+class MLReader(object):
"""
.. note:: Experimental
- Utility class that can load ML instances through their Scala implementation.
+ Utility class that can load ML instances.
.. versionadded:: 2.0.0
"""
+ def load(self, path):
+ """Load the ML instance from the input path."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for loading."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, clazz):
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()
@@ -142,11 +188,10 @@ class JavaMLReader(object):
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._java_obj = java_obj
- instance._resetUid(java_obj.uid())
- instance._transfer_params_from_java()
- return instance
+ if not hasattr(self._clazz, "_from_java"):
+ raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
+ % self._clazz)
+ return self._clazz._from_java(java_obj)
def context(self, sqlContext):
"""Sets the SQL context to use for loading."""
@@ -164,7 +209,7 @@ class JavaMLReader(object):
if clazz.__name__ in ("Pipeline", "PipelineModel"):
# Remove the last package name "pipeline" for Pipeline and PipelineModel.
java_package = ".".join(java_package.split(".")[0:-1])
- return ".".join([java_package, clazz.__name__])
+ return java_package + "." + clazz.__name__
@classmethod
def _load_java_obj(cls, clazz):
@@ -181,7 +226,7 @@ class MLReadable(object):
"""
.. note:: Experimental
- Mixin for instances that provide JavaMLReader.
+ Mixin for instances that provide :py:class:`MLReader`.
.. versionadded:: 2.0.0
"""
@@ -189,9 +234,21 @@ class MLReadable(object):
@classmethod
def read(cls):
"""Returns an JavaMLReader instance for this class."""
- return JavaMLReader(cls)
+ raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
return cls.read().load(path)
+
+
+@inherit_doc
+class JavaMLReadable(MLReadable):
+ """
+ (Private) Mixin for instances that provide JavaMLReader.
+ """
+
+ @classmethod
+ def read(cls):
+ """Returns an JavaMLReader instance for this class."""
+ return JavaMLReader(cls)