aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java26
1 files changed, 7 insertions, 19 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java
index 66387b9df5..032c168b94 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java
@@ -27,8 +27,6 @@ import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
@@ -53,31 +51,21 @@ class JavaDecisionTreeClassificationExample {
// Set parameters.
// Empty categoricalFeaturesInfo indicates all features are continuous.
- Integer numClasses = 2;
+ int numClasses = 2;
Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
String impurity = "gini";
- Integer maxDepth = 5;
- Integer maxBins = 32;
+ int maxDepth = 5;
+ int maxBins = 32;
// Train a DecisionTree model for classification.
- final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
+ DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses,
categoricalFeaturesInfo, impurity, maxDepth, maxBins);
// Evaluate model on test instances and compute test error
JavaPairRDD<Double, Double> predictionAndLabel =
- testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
- @Override
- public Tuple2<Double, Double> call(LabeledPoint p) {
- return new Tuple2<>(model.predict(p.features()), p.label());
- }
- });
- Double testErr =
- 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() {
- @Override
- public Boolean call(Tuple2<Double, Double> pl) {
- return !pl._1().equals(pl._2());
- }
- }).count() / testData.count();
+ testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label()));
+ double testErr =
+ predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());