aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/mllib-naive-bayes.md18
1 files changed, 6 insertions, 12 deletions
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 4b3a7cab32..1d1d7dcf6f 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -51,9 +51,8 @@ val training = splits(0)
val test = splits(1)
val model = NaiveBayes.train(training, lambda = 1.0)
-val prediction = model.predict(test.map(_.features))
-val predictionAndLabel = prediction.zip(test.map(_.label))
+val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
{% endhighlight %}
</div>
@@ -71,6 +70,7 @@ can be used for evaluation and prediction.
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.classification.NaiveBayes;
import org.apache.spark.mllib.classification.NaiveBayesModel;
import org.apache.spark.mllib.regression.LabeledPoint;
@@ -81,18 +81,12 @@ JavaRDD<LabeledPoint> test = ... // test set
final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
-JavaRDD<Double> prediction =
- test.map(new Function<LabeledPoint, Double>() {
- @Override public Double call(LabeledPoint p) {
- return model.predict(p.features());
- }
- });
JavaPairRDD<Double, Double> predictionAndLabel =
- prediction.zip(test.map(new Function<LabeledPoint, Double>() {
- @Override public Double call(LabeledPoint p) {
- return p.label();
+ test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
+ @Override public Tuple2<Double, Double> call(LabeledPoint p) {
+ return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
}
- }));
+ });
double accuracy = 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
@Override public Boolean call(Tuple2<Double, Double> pl) {
return pl._1() == pl._2();