aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java13
1 files changed, 3 insertions, 10 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java
index 2d12bdd2a6..03670383b7 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java
@@ -21,7 +21,6 @@ package org.apache.spark.examples.mllib;
import scala.Tuple2;
import org.apache.spark.api.java.*;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
@@ -46,19 +45,13 @@ public class JavaMulticlassClassificationMetricsExample {
JavaRDD<LabeledPoint> test = splits[1];
// Run training algorithm to build the model.
- final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+ LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(3)
.run(training.rdd());
// Compute raw scores on the test set.
- JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
- new Function<LabeledPoint, Tuple2<Object, Object>>() {
- public Tuple2<Object, Object> call(LabeledPoint p) {
- Double prediction = model.predict(p.features());
- return new Tuple2<Object, Object>(prediction, p.label());
- }
- }
- );
+ JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p ->
+ new Tuple2<>(model.predict(p.features()), p.label()));
// Get evaluation metrics.
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());