diff options
Diffstat (limited to 'python/pyspark/ml/util.py')
-rw-r--r-- | python/pyspark/ml/util.py | 89 |
1 files changed, 73 insertions, 16 deletions
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 42801c91bb..6703851262 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -74,18 +74,38 @@ class Identifiable(object): @inherit_doc -class JavaMLWriter(object): +class MLWriter(object): """ .. note:: Experimental - Utility class that can save ML instances through their Scala implementation. + Utility class that can save ML instances. .. versionadded:: 2.0.0 """ + def save(self, path): + """Save the ML instance to the input path.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def overwrite(self): + """Overwrites if the output path already exists.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + """ + def __init__(self, instance): - instance._transfer_params_to_java() - self._jwrite = instance._java_obj.write() + super(JavaMLWriter, self).__init__() + _java_obj = instance._to_java() + self._jwrite = _java_obj.write() def save(self, path): """Save the ML instance to the input path.""" @@ -109,14 +129,14 @@ class MLWritable(object): """ .. note:: Experimental - Mixin for ML instances that provide JavaMLWriter. + Mixin for ML instances that provide :py:class:`MLWriter`. .. versionadded:: 2.0.0 """ def write(self): """Returns an JavaMLWriter instance for this ML instance.""" - return JavaMLWriter(self) + raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" @@ -124,15 +144,41 @@ class MLWritable(object): @inherit_doc -class JavaMLReader(object): +class JavaMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + +@inherit_doc +class MLReader(object): """ .. note:: Experimental - Utility class that can load ML instances through their Scala implementation. + Utility class that can load ML instances. .. versionadded:: 2.0.0 """ + def load(self, path): + """Load the ML instance from the input path.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + """ + def __init__(self, clazz): self._clazz = clazz self._jread = self._load_java_obj(clazz).read() @@ -142,11 +188,10 @@ class JavaMLReader(object): if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) java_obj = self._jread.load(path) - instance = self._clazz() - instance._java_obj = java_obj - instance._resetUid(java_obj.uid()) - instance._transfer_params_from_java() - return instance + if not hasattr(self._clazz, "_from_java"): + raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" + % self._clazz) + return self._clazz._from_java(java_obj) def context(self, sqlContext): """Sets the SQL context to use for loading.""" @@ -164,7 +209,7 @@ class JavaMLReader(object): if clazz.__name__ in ("Pipeline", "PipelineModel"): # Remove the last package name "pipeline" for Pipeline and PipelineModel. java_package = ".".join(java_package.split(".")[0:-1]) - return ".".join([java_package, clazz.__name__]) + return java_package + "." + clazz.__name__ @classmethod def _load_java_obj(cls, clazz): @@ -181,7 +226,7 @@ class MLReadable(object): """ .. note:: Experimental - Mixin for instances that provide JavaMLReader. + Mixin for instances that provide :py:class:`MLReader`. .. versionadded:: 2.0.0 """ @@ -189,9 +234,21 @@ class MLReadable(object): @classmethod def read(cls): """Returns an JavaMLReader instance for this class.""" - return JavaMLReader(cls) + raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) @classmethod def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" return cls.read().load(path) + + +@inherit_doc +class JavaMLReadable(MLReadable): + """ + (Private) Mixin for instances that provide JavaMLReader. + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) |