aboutsummaryrefslogtreecommitdiff
path: root/docs/mllib-decision-tree.md
diff options
context:
space:
mode:
Diffstat (limited to 'docs/mllib-decision-tree.md')
-rw-r--r--docs/mllib-decision-tree.md20
1 files changed, 20 insertions, 0 deletions
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 6675133a81..4695d1cde4 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -194,6 +194,7 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -221,6 +222,9 @@ val labelAndPreds = testData.map { point =>
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification tree model:\n" + model.toDebugString)
+
+model.save("myModelPath")
+val sameModel = DecisionTreeModel.load("myModelPath")
{% endhighlight %}
</div>
@@ -279,10 +283,16 @@ Double testErr =
}).count() / testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());
+
+model.save("myModelPath");
+DecisionTreeModel sameModel = DecisionTreeModel.load("myModelPath");
{% endhighlight %}
</div>
<div data-lang="python">
+
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
@@ -324,6 +334,7 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
<div data-lang="scala">
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -350,6 +361,9 @@ val labelsAndPredictions = testData.map { point =>
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression tree model:\n" + model.toDebugString)
+
+model.save("myModelPath")
+val sameModel = DecisionTreeModel.load("myModelPath")
{% endhighlight %}
</div>
@@ -414,10 +428,16 @@ Double testMSE =
}) / data.count();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());
+
+model.save("myModelPath");
+DecisionTreeModel sameModel = DecisionTreeModel.load("myModelPath");
{% endhighlight %}
</div>
<div data-lang="python">
+
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree