aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java19
1 files changed, 4 insertions, 15 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java
index f4ec04b0c6..d80dbe8000 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaNaiveBayesExample.java
@@ -19,8 +19,6 @@ package org.apache.spark.examples.mllib;
// $example on$
import scala.Tuple2;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
@@ -41,20 +39,11 @@ public class JavaNaiveBayesExample {
JavaRDD<LabeledPoint>[] tmp = inputData.randomSplit(new double[]{0.6, 0.4});
JavaRDD<LabeledPoint> training = tmp[0]; // training set
JavaRDD<LabeledPoint> test = tmp[1]; // test set
- final NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
+ NaiveBayesModel model = NaiveBayes.train(training.rdd(), 1.0);
JavaPairRDD<Double, Double> predictionAndLabel =
- test.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
- @Override
- public Tuple2<Double, Double> call(LabeledPoint p) {
- return new Tuple2<>(model.predict(p.features()), p.label());
- }
- });
- double accuracy = predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
- @Override
- public Boolean call(Tuple2<Double, Double> pl) {
- return pl._1().equals(pl._2());
- }
- }).count() / (double) test.count();
+ test.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
+ double accuracy =
+ predictionAndLabel.filter(pl -> pl._1().equals(pl._2())).count() / (double) test.count();
// Save and load model
model.save(jsc.sc(), "target/tmp/myNaiveBayesModel");