about summary refs log tree commit diff
path: root/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java 47
1 files changed, 17 insertions, 30 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
index 9ca9a7847c..324a781c1a 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLinearRegressionWithSGDExample.java
@@ -23,9 +23,8 @@ import org.apache.spark.api.java.JavaSparkContext;
// $example on$
import scala.Tuple2;
-import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
@@ -44,43 +43,31 @@ public class JavaLinearRegressionWithSGDExample {
// Load and parse the data
String path = "data/mllib/ridge-data/lpsa.data";
JavaRDD<String> data = sc.textFile(path);
- JavaRDD<LabeledPoint> parsedData = data.map(
- new Function<String, LabeledPoint>() {
- public LabeledPoint call(String line) {
- String[] parts = line.split(",");
- String[] features = parts[1].split(" ");
- double[] v = new double[features.length];
- for (int i = 0; i < features.length - 1; i++) {
- v[i] = Double.parseDouble(features[i]);
- }
- return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
- }
+ JavaRDD<LabeledPoint> parsedData = data.map(line -> {
+ String[] parts = line.split(",");
+ String[] features = parts[1].split(" ");
+ double[] v = new double[features.length];
+ for (int i = 0; i < features.length - 1; i++) {
+ v[i] = Double.parseDouble(features[i]);
}
- );
+ return new LabeledPoint(Double.parseDouble(parts[0]), Vectors.dense(v));
+ });
parsedData.cache();
// Building the model
int numIterations = 100;
double stepSize = 0.00000001;
- final LinearRegressionModel model =
+ LinearRegressionModel model =
LinearRegressionWithSGD.train(JavaRDD.toRDD(parsedData), numIterations, stepSize);
// Evaluate model on training examples and compute training error
- JavaRDD<Tuple2<Double, Double>> valuesAndPreds = parsedData.map(
- new Function<LabeledPoint, Tuple2<Double, Double>>() {
- public Tuple2<Double, Double> call(LabeledPoint point) {
- double prediction = model.predict(point.features());
- return new Tuple2<>(prediction, point.label());
- }
- }
- );
- double MSE = new JavaDoubleRDD(valuesAndPreds.map(
- new Function<Tuple2<Double, Double>, Object>() {
- public Object call(Tuple2<Double, Double> pair) {
- return Math.pow(pair._1() - pair._2(), 2.0);
- }
- }
- ).rdd()).mean();
+ JavaPairRDD<Double, Double> valuesAndPreds = parsedData.mapToPair(point ->
+ new Tuple2<>(model.predict(point.features()), point.label()));
+
+ double MSE = valuesAndPreds.mapToDouble(pair -> {
+ double diff = pair._1() - pair._2();
+ return diff * diff;
+ }).mean();
System.out.println("training Mean Squared Error = " + MSE);
// Save and load model