diff options
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java | 14 |
1 files changed, 4 insertions, 10 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java index 7fc371ec0f..26b8a6e9fa 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLogisticRegressionWithLBFGSExample.java @@ -23,8 +23,8 @@ import org.apache.spark.SparkContext; // $example on$ import scala.Tuple2; +import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.classification.LogisticRegressionModel; import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS; import org.apache.spark.mllib.evaluation.MulticlassMetrics; @@ -49,19 +49,13 @@ public class JavaLogisticRegressionWithLBFGSExample { JavaRDD<LabeledPoint> test = splits[1]; // Run training algorithm to build the model. - final LogisticRegressionModel model = new LogisticRegressionWithLBFGS() + LogisticRegressionModel model = new LogisticRegressionWithLBFGS() .setNumClasses(10) .run(training.rdd()); // Compute raw scores on the test set. - JavaRDD<Tuple2<Object, Object>> predictionAndLabels = test.map( - new Function<LabeledPoint, Tuple2<Object, Object>>() { - public Tuple2<Object, Object> call(LabeledPoint p) { - Double prediction = model.predict(p.features()); - return new Tuple2<Object, Object>(prediction, p.label()); - } - } - ); + JavaPairRDD<Object, Object> predictionAndLabels = test.mapToPair(p -> + new Tuple2<>(model.predict(p.features()), p.label())); // Get evaluation metrics. MulticlassMetrics metrics = new MulticlassMetrics(predictionAndLabels.rdd()); |