diff options
Diffstat (limited to 'examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java | 26 |
1 files changed, 7 insertions, 19 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java index 66387b9df5..032c168b94 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaDecisionTreeClassificationExample.java @@ -27,8 +27,6 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.DecisionTree; import org.apache.spark.mllib.tree.model.DecisionTreeModel; @@ -53,31 +51,21 @@ class JavaDecisionTreeClassificationExample { // Set parameters. // Empty categoricalFeaturesInfo indicates all features are continuous. - Integer numClasses = 2; + int numClasses = 2; Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>(); String impurity = "gini"; - Integer maxDepth = 5; - Integer maxBins = 32; + int maxDepth = 5; + int maxBins = 32; // Train a DecisionTree model for classification. - final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, + DecisionTreeModel model = DecisionTree.trainClassifier(trainingData, numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins); // Evaluate model on test instances and compute test error JavaPairRDD<Double, Double> predictionAndLabel = - testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() { - @Override - public Tuple2<Double, Double> call(LabeledPoint p) { - return new Tuple2<>(model.predict(p.features()), p.label()); - } - }); - Double testErr = - 1.0 * predictionAndLabel.filter(new Function<Tuple2<Double, Double>, Boolean>() { - @Override - public Boolean call(Tuple2<Double, Double> pl) { - return !pl._1().equals(pl._2()); - } - }).count() / testData.count(); + testData.mapToPair(p -> new Tuple2<>(model.predict(p.features()), p.label())); + double testErr = + predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count() / (double) testData.count(); System.out.println("Test Error: " + testErr); System.out.println("Learned classification tree model:\n" + model.toDebugString()); |