diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-03-02 22:27:01 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-02 22:27:01 -0800 |
commit | 7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b (patch) | |
tree | 4fc615db1b5144cf7b430ea3bc26bda2cd49cad8 /python/pyspark/mllib/tree.py | |
parent | 54d19689ff8d786acde5b8ada6741854ffadadea (diff) | |
download | spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.tar.gz spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.tar.bz2 spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.zip |
[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib
Similar to `MatrixFactorizaionModel`, we only need wrappers to support save/load for tree models in Python.
jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #4854 from mengxr/SPARK-6097 and squashes the following commits:
4586a4d [Xiangrui Meng] fix more typos
8ebcac2 [Xiangrui Meng] fix python style
91172d8 [Xiangrui Meng] fix typos
201b3b9 [Xiangrui Meng] update user guide
b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib
Diffstat (limited to 'python/pyspark/mllib/tree.py')
-rw-r--r-- | python/pyspark/mllib/tree.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 73618f0449..bf288d7644 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -23,12 +23,13 @@ from pyspark import SparkContext, RDD from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees'] -class TreeEnsembleModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper, JavaSaveable): def predict(self, x): """ Predict values for a single data point or an RDD of points using @@ -66,7 +67,7 @@ class TreeEnsembleModel(JavaModelWrapper): return self._java_model.toDebugString() -class DecisionTreeModel(JavaModelWrapper): +class DecisionTreeModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -103,6 +104,10 @@ class DecisionTreeModel(JavaModelWrapper): """ full model. """ return self._java_model.toDebugString() + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.DecisionTreeModel" + class DecisionTree(object): """ @@ -227,13 +232,17 @@ class DecisionTree(object): @inherit_doc -class RandomForestModel(TreeEnsembleModel): +class RandomForestModel(TreeEnsembleModel, JavaLoader): """ .. note:: Experimental Represents a random forest model. """ + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.RandomForestModel" + class RandomForest(object): """ @@ -406,13 +415,17 @@ class RandomForest(object): @inherit_doc -class GradientBoostedTreesModel(TreeEnsembleModel): +class GradientBoostedTreesModel(TreeEnsembleModel, JavaLoader): """ .. note:: Experimental Represents a gradient-boosted tree model. """ + @classmethod + def _java_loader_class(cls): + return "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel" + class GradientBoostedTrees(object): """ |