diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java index 8dd29061da..60585d2727 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java @@ -28,6 +28,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.mllib.tree.configuration.Algo; import org.apache.spark.mllib.tree.configuration.Strategy; @@ -95,6 +97,14 @@ public class JavaDecisionTreeSuite implements Serializable { DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy); + // java compatibility test + JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint, Vector>() { + @Override + public Vector call(LabeledPoint v1) { + return v1.features(); + } + })); + int numCorrect = validatePrediction(arr, model); Assert.assertTrue(numCorrect == rdd.count()); } |