aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/util.py')
-rw-r--r-- python/pyspark/ml/util.py 89
1 files changed, 73 insertions, 16 deletions
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 42801c91bb..6703851262 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -74,18 +74,38 @@ class Identifiable(object):
@inherit_doc
-class JavaMLWriter(object):
+class MLWriter(object):
"""
.. note:: Experimental
- Utility class that can save ML instances through their Scala implementation.
+ Utility class that can save ML instances.
.. versionadded:: 2.0.0
"""
+ def save(self, path):
+ """Save the ML instance to the input path."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def overwrite(self):
+ """Overwrites if the output path already exists."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for saving."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, instance):
- instance._transfer_params_to_java()
- self._jwrite = instance._java_obj.write()
+ super(JavaMLWriter, self).__init__()
+ _java_obj = instance._to_java()
+ self._jwrite = _java_obj.write()
def save(self, path):
"""Save the ML instance to the input path."""
@@ -109,14 +129,14 @@ class MLWritable(object):
"""
.. note:: Experimental
- Mixin for ML instances that provide JavaMLWriter.
+ Mixin for ML instances that provide :py:class:`MLWriter`.
.. versionadded:: 2.0.0
"""
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
@@ -124,15 +144,41 @@ class MLWritable(object):
@inherit_doc
-class JavaMLReader(object):
+class JavaMLWritable(MLWritable):
+ """
+ (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
+ """
+
+ def write(self):
+ """Returns a JavaMLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+
+@inherit_doc
+class MLReader(object):
"""
.. note:: Experimental
- Utility class that can load ML instances through their Scala implementation.
+ Utility class that can load ML instances.
.. versionadded:: 2.0.0
"""
+ def load(self, path):
+ """Load the ML instance from the input path."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for loading."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, clazz):
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()
@@ -142,11 +188,10 @@ class JavaMLReader(object):
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._java_obj = java_obj
- instance._resetUid(java_obj.uid())
- instance._transfer_params_from_java()
- return instance
+ if not hasattr(self._clazz, "_from_java"):
+ raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
+ % self._clazz)
+ return self._clazz._from_java(java_obj)
def context(self, sqlContext):
"""Sets the SQL context to use for loading."""
@@ -164,7 +209,7 @@ class JavaMLReader(object):
if clazz.__name__ in ("Pipeline", "PipelineModel"):
# Remove the last package name "pipeline" for Pipeline and PipelineModel.
java_package = ".".join(java_package.split(".")[0:-1])
- return ".".join([java_package, clazz.__name__])
+ return java_package + "." + clazz.__name__
@classmethod
def _load_java_obj(cls, clazz):
@@ -181,7 +226,7 @@ class MLReadable(object):
"""
.. note:: Experimental
- Mixin for instances that provide JavaMLReader.
+ Mixin for instances that provide :py:class:`MLReader`.
.. versionadded:: 2.0.0
"""
@@ -189,9 +234,21 @@ class MLReadable(object):
@classmethod
def read(cls):
"""Returns an JavaMLReader instance for this class."""
- return JavaMLReader(cls)
+ raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
return cls.read().load(path)
+
+
+@inherit_doc
+class JavaMLReadable(MLReadable):
+ """
+ (Private) Mixin for instances that provide :py:class:`JavaMLReader`.
+ """
+
+ @classmethod
+ def read(cls):
+ """Returns a JavaMLReader instance for this class."""
+ return JavaMLReader(cls)