aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/util.py')
-rw-r--r--python/pyspark/mllib/util.py58
1 files changed, 58 insertions, 0 deletions
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 4ed978b454..17d43eadba 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -168,6 +168,64 @@ class MLUtils(object):
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
class Saveable(object):
    """
    Mixin for models and transformers that can be persisted as files.

    This class only declares the interface; :py:meth:`save` must be
    overridden by the concrete class.
    """

    def save(self, sc, path):
        """
        Persist this model under the given path.

        An implementation is expected to write:
          * human-readable (JSON) model metadata to path/metadata/
          * Parquet formatted data to path/data/

        A saved model may be read back with :py:meth:`Loader.load`.

        :param sc: Spark context used to save model data.
        :param path: Path of the directory in which to save this
                     model. If the directory already exists, this
                     method throws an exception.
        :raises NotImplementedError: always, in this base mixin.
        """
        raise NotImplementedError()
+
+
class Loader(object):
    """
    Mixin for classes which can load saved models from files.

    This class only declares the interface; :py:meth:`load` must be
    overridden by the concrete class.
    """

    @classmethod
    def load(cls, sc, path):
        """
        Load a model from the given path. The model should have been
        saved using :py:meth:`Saveable.save`.

        :param sc: Spark context used for loading model files.
        :param path: Path specifying the directory to which the model
                     was saved.
        :return: model instance
        :raises NotImplementedError: always, in this base mixin.
        """
        # NotImplemented is a comparison sentinel, not an exception;
        # raising it produces a TypeError. NotImplementedError is the
        # exception that signals a method left for subclasses.
        raise NotImplementedError
+
+
class JavaLoader(Loader):
    """
    Mixin for classes which load saved models by delegating to the
    Scala implementation of the same class.
    """

    @classmethod
    def load(cls, sc, path):
        """
        Load a model from the given path via the JVM-side class.

        The JVM class name is derived from this Python class: the
        module path with "pyspark" replaced by "org.apache.spark",
        followed by the class name.

        :param sc: Spark context; its ``_jvm`` and ``_jsc`` attributes
                   are used to reach the JVM side.
        :param path: Path specifying the directory to which the model
                     was saved.
        :return: an instance of ``cls`` constructed from the object
                 returned by the JVM-side ``load``.
        """
        scala_package = cls.__module__.replace("pyspark", "org.apache.spark")
        qualified_name = scala_package + "." + cls.__name__

        # Resolve the dotted name one attribute at a time, starting
        # from the JVM entry point.
        target = sc._jvm
        for part in qualified_name.split("."):
            target = getattr(target, part)

        # NOTE(review): sc._jsc.sc() presumably yields the underlying
        # Scala SparkContext -- confirm against the context class.
        loaded_model = target.load(sc._jsc.sc(), path)
        return cls(loaded_model)
+
+
def _test():
import doctest
from pyspark.context import SparkContext