aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java14
1 files changed, 4 insertions, 10 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java
index 7fc371ec0f..26b8a6e9fa 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java
@@ -23,8 +23,8 @@ import org.apache.spark.SparkContext;
// $example on$
import scala.Tuple2;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.classification.LogisticRegressionModel;
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
@@ -49,19 +49,13 @@ public class JavaLogisticRegressionWithLBFGSExample {
JavaRDD<LabeledPoint> test = splits[1];
// Run training algorithm to build the model.
- final LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
+ LogisticRegressionModel model = new LogisticRegressionWithLBFGS()
.setNumClasses(10)
.run(training.rdd());
// Compute raw scores on the test set.
- JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map(
- new Function<LabeledPoint, Tuple2<Object, Object>>() {
- public Tuple2<Object, Object> call(LabeledPoint p) {
- Double prediction = model.predict(p.features());
- return new Tuple2<Object, Object>(prediction, p.label());
- }
- }
- );
+ JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p ->
+ new Tuple2<>(model.predict(p.features()), p.label()));
// Get evaluation metrics.
MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd());