about | summary | refs | log | tree | commit | diff
path: root/examples
diff options
context:
space:
mode:
author: movelikeriver <mars.lenjoy@gmail.com> 2016-02-22 23:58:54 -0800
committer: Xiangrui Meng <meng@databricks.com> 2016-02-22 23:58:54 -0800
commit: 5cd3e6f60b839909210500b319cf312de026dd49 (patch)
tree: 10ccd407c888a21c351dac6ec4ab980b2bcf0b10 /examples
parent: 764ca18037b6b1884fbc4be9a011714a81495020 (diff)
download: spark-5cd3e6f60b839909210500b319cf312de026dd49.tar.gz
spark-5cd3e6f60b839909210500b319cf312de026dd49.tar.bz2
spark-5cd3e6f60b839909210500b319cf312de026dd49.zip
[SPARK-13257][IMPROVEMENT] Refine naive Bayes example by checking model after loading it
Refine naive Bayes example by checking model after loading it.

Author: movelikeriver <mars.lenjoy@gmail.com>

Closes #11125 from movelikeriver/naive_bayes.
Diffstat (limited to 'examples')
-rw-r--r-- examples/src/main/python/mllib/naive_bayes_example.py | 17
1 files changed, 15 insertions, 2 deletions
diff --git a/examples/src/main/python/mllib/naive_bayes_example.py b/examples/src/main/python/mllib/naive_bayes_example.py
index f5e120c678..e7d5893d67 100644
--- a/examples/src/main/python/mllib/naive_bayes_example.py
+++ b/examples/src/main/python/mllib/naive_bayes_example.py
@@ -17,9 +17,15 @@
"""
NaiveBayes Example.
+
+Usage:
+ `spark-submit --master local[4] examples/src/main/python/mllib/naive_bayes_example.py`
"""
+
from __future__ import print_function
+import shutil
+
from pyspark import SparkContext
# $example on$
from pyspark.mllib.classification import NaiveBayes, NaiveBayesModel
@@ -50,8 +56,15 @@ if __name__ == "__main__":
# Make prediction and test accuracy.
predictionAndLabel = test.map(lambda p: (model.predict(p.features), p.label))
accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
+ print('model accuracy {}'.format(accuracy))
# Save and load model
- model.save(sc, "target/tmp/myNaiveBayesModel")
- sameModel = NaiveBayesModel.load(sc, "target/tmp/myNaiveBayesModel")
+ output_dir = 'target/tmp/myNaiveBayesModel'
+ shutil.rmtree(output_dir, ignore_errors=True)
+ model.save(sc, output_dir)
+ sameModel = NaiveBayesModel.load(sc, output_dir)
+ predictionAndLabel = test.map(lambda p: (sameModel.predict(p.features), p.label))
+ accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
+ print('sameModel accuracy {}'.format(accuracy))
+
# $example off$