diff options
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java | 13 |
1 files changed, 3 insertions, 10 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java index 2d12bdd2a6..03670383b7 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMulticlassClassificationMetricsExample.java @@ -21,7 +21,6 @@ package org.apache.spark.examples.mllib; import scala.Tuple2; import org.apache.spark.api.java.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; import org.apache.spark.mllib.evaluation.MulticlassMetrics; @@ -46,19 +45,13 @@ public class JavaMulticlassClassificationMetricsExample { JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(3) .run(training.rdd()); // Compute raw scores on the test set. - JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( - new Function<LabeledPoint, Tuple2<Object, Object>>() { - public Tuple2<Object, Object> call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2<Object, Object>(prediction, p.label()); - } - } - ); + JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); |