diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-03-02 22:27:01 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-02 22:27:01 -0800 |
commit | 7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b (patch) | |
tree | 4fc615db1b5144cf7b430ea3bc26bda2cd49cad8 /python/pyspark/mllib/util.py | |
parent | 54d19689ff8d786acde5b8ada6741854ffadadea (diff) | |
download | spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.tar.gz spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.tar.bz2 spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.zip |
[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib
Similar to `MatrixFactorizaionModel`, we only need wrappers to support save/load for tree models in Python.
jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #4854 from mengxr/SPARK-6097 and squashes the following commits:
4586a4d [Xiangrui Meng] fix more typos
8ebcac2 [Xiangrui Meng] fix python style
91172d8 [Xiangrui Meng] fix typos
201b3b9 [Xiangrui Meng] update user guide
b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib
Diffstat (limited to 'python/pyspark/mllib/util.py')
-rw-r--r-- | python/pyspark/mllib/util.py | 37 |
1 files changed, 33 insertions, 4 deletions
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py index 17d43eadba..e877c720ac 100644 --- a/python/pyspark/mllib/util.py +++ b/python/pyspark/mllib/util.py @@ -18,7 +18,7 @@ import numpy as np import warnings -from pyspark.mllib.common import callMLlibFunc +from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper, inherit_doc from pyspark.mllib.linalg import Vectors, SparseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint @@ -191,6 +191,17 @@ class Saveable(object): raise NotImplementedError +@inherit_doc +class JavaSaveable(Saveable): + """ + Mixin for models that provide save() through their Scala + implementation. + """ + + def save(self, sc, path): + self._java_model.save(sc._jsc.sc(), path) + + class Loader(object): """ Mixin for classes which can load saved models from files. @@ -210,6 +221,7 @@ class Loader(object): raise NotImplemented +@inherit_doc class JavaLoader(Loader): """ Mixin for classes which can load saved models using its Scala @@ -217,13 +229,30 @@ class JavaLoader(Loader): """ @classmethod - def load(cls, sc, path): + def _java_loader_class(cls): + """ + Returns the full class name of the Java loader. The default + implementation replaces "pyspark" by "org.apache.spark" in + the Python full class name. + """ java_package = cls.__module__.replace("pyspark", "org.apache.spark") - java_class = ".".join([java_package, cls.__name__]) + return ".".join([java_package, cls.__name__]) + + @classmethod + def _load_java(cls, sc, path): + """ + Load a Java model from the given path. + """ + java_class = cls._java_loader_class() java_obj = sc._jvm for name in java_class.split("."): java_obj = getattr(java_obj, name) - return cls(java_obj.load(sc._jsc.sc(), path)) + return java_obj.load(sc._jsc.sc(), path) + + @classmethod + def load(cls, sc, path): + java_model = cls._load_java(sc, path) + return cls(java_model) def _test(): |