aboutsummaryrefslogtreecommitdiff
path: root/docs/mllib-ensembles.md
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-03-02 22:27:01 -0800
committerXiangrui Meng <meng@databricks.com>2015-03-02 22:27:01 -0800
commit7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b (patch)
tree4fc615db1b5144cf7b430ea3bc26bda2cd49cad8 /docs/mllib-ensembles.md
parent54d19689ff8d786acde5b8ada6741854ffadadea (diff)
downloadspark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.tar.gz
spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.tar.bz2
spark-7e53a79c30511dbd0e5d9878a4b8b0f5bc94e68b.zip
[SPARK-6097][MLLIB] Support tree model save/load in PySpark/MLlib
Similar to `MatrixFactorizaionModel`, we only need wrappers to support save/load for tree models in Python. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #4854 from mengxr/SPARK-6097 and squashes the following commits: 4586a4d [Xiangrui Meng] fix more typos 8ebcac2 [Xiangrui Meng] fix python style 91172d8 [Xiangrui Meng] fix typos 201b3b9 [Xiangrui Meng] update user guide b5158e2 [Xiangrui Meng] support tree model save/load in PySpark/MLlib
Diffstat (limited to 'docs/mllib-ensembles.md')
-rw-r--r--docs/mllib-ensembles.md32
1 files changed, 20 insertions, 12 deletions
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index ec1ef38b45..cbfb682609 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -202,10 +202,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
<div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.tree import RandomForest
+from pyspark.mllib.tree import RandomForest, RandomForestModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
@@ -228,6 +226,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
print('Test Error = ' + str(testErr))
print('Learned classification forest model:')
print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
</div>
@@ -354,10 +356,8 @@ RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
<div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.tree import RandomForest
+from pyspark.mllib.tree import RandomForest, RandomForestModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file into an RDD of LabeledPoint.
@@ -380,6 +380,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression forest model:')
print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
</div>
@@ -581,10 +585,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
<div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file.
@@ -605,6 +607,10 @@ testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(tes
print('Test Error = ' + str(testErr))
print('Learned classification GBT model:')
print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
</div>
@@ -732,10 +738,8 @@ GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "m
<div data-lang="python">
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.tree import GradientBoostedTrees, GradientBoostedTreesModel
from pyspark.mllib.util import MLUtils
# Load and parse the data file.
@@ -756,6 +760,10 @@ testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / flo
print('Test Mean Squared Error = ' + str(testMSE))
print('Learned regression GBT model:')
print(model.toDebugString())
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
</div>