diff options
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java | 47 |
1 files changed, 17 insertions, 30 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
index 9ca9a7847c..324a781c1a 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
@@ -23,9 +23,8 @@ import org.apache.spark.api.java.JavaSparkContext;
 // $example on$
 import scala.Tuple2;
 
-import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
 import org.apache.spark.mllib.linalg.Vectors;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.mllib.regression.LinearRegressionModel;
@@ -44,43 +43,31 @@ public class JavaLinearRegressionWithSGDExample {
     // Load and parse the data
     String path = "data/mllib/ridge-data/lpsa.data";
     JavaRDD<String> data = sc.textFile(path);
-    JavaRDD<LabeledPoint> parsedData = data.map(
-      new Function<String, LabeledPoint>() {
-        public LabeledPoint call(String line) {
-          String[] parts = line.split(",");
-          String[] features = parts[1].split(" ");
-          double[] v = new double[features.length];
-          for (int i = 0; i < features.length - 1; i++) {
-            v[i] = Double.parseDouble(features[i]);
-          }
-          return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
-        }
+    JavaRDD<LabeledPoint> parsedData = data.map(line -> {
+      String[] parts = line.split(",");
+      String[] features = parts[1].split(" ");
+      double[] v = new double[features.length];
+      for (int i = 0; i < features.length - 1; i++) {
+        v[i] = Double.parseDouble(features[i]);
       }
-    );
+      return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+    });
     parsedData.cache();
 
     // Building the model
     int numIterations = 100;
     double stepSize = 0.00000001;
-    final LinearRegressionModel model =
+    LinearRegressionModel model =
       LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize);
 
     // Evaluate model on training examples and compute training error
-    JavaRDD<Tuple2<Double, Double>> valuesAndPreds = parsedData.map(
-      new Function<LabeledPoint, Tuple2<Double, Double>>() {
-        public Tuple2<Double, Double> call(LabeledPoint point) {
-          double prediction = model.predict(point.features());
-          return new Tuple2<>(prediction, point.label());
-        }
-      }
-    );
-    double MSE = new JavaDoubleRDD(valuesAndPreds.map(
-      new Function<Tuple2<Double, Double>, Object>() {
-        public Object call(Tuple2<Double, Double> pair) {
-          return Math.pow(pair._1() - pair._2(), 2.0);
-        }
-      }
-    ).rdd()).mean();
+    JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point ->
+      new Tuple2<>(model.predict(point.features()), point.label()));
+
+    double MSE = valuesAndPreds.mapToDouble(pair -> {
+      double diff = pair._1() - pair._2();
+      return diff * diff;
+    }).mean();
     System.out.println("training Mean Squared Error = " + MSE);
 
     // Save and load model