aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java10
1 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 8dd29061da..60585d2727 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -28,6 +28,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
@@ -95,6 +97,14 @@ public class JavaDecisionTreeSuite implements Serializable {
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
+ // java compatibility test
+ JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint, Vector>() {
+ @Override
+ public Vector call(LabeledPoint v1) {
+ return v1.features();
+ }
+ }));
+
int numCorrect = validatePrediction(arr, model);
Assert.assertTrue(numCorrect == rdd.count());
}