aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-05-09 09:08:54 +0100
committerSean Owen <sowen@cloudera.com>2016-05-09 09:08:54 +0100
commit68abc1b4e9afbb6c2a87689221a46b835dded102 (patch)
tree654b2c7b384cf5e9a0225e7fd3696140b6cad812
parent635ef407e11dec41ae9bc428935fb8fdaa482f7e (diff)
downloadspark-68abc1b4e9afbb6c2a87689221a46b835dded102.tar.gz
spark-68abc1b4e9afbb6c2a87689221a46b835dded102.tar.bz2
spark-68abc1b4e9afbb6c2a87689221a46b835dded102.zip
[SPARK-14814][MLLIB] API: Java compatibility, docs
## What changes were proposed in this pull request? jira: https://issues.apache.org/jira/browse/SPARK-14814 fix a java compatibility function in mllib DecisionTreeModel. As synced in jira, other compatibility issues don't need fixes. ## How was this patch tested? existing ut Author: Yuhao Yang <hhbyyh@gmail.com> Closes #12971 from hhbyyh/javacompatibility.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java10
2 files changed, 12 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a87f8a6cde..c13b9a66c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -75,8 +75,8 @@ class DecisionTreeModel @Since("1.0.0") (
* @return JavaRDD of predictions for each of the given data points
*/
@Since("1.2.0")
- def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
- predict(features.rdd)
+ def predict(features: JavaRDD[Vector]): JavaRDD[java.lang.Double] = {
+ predict(features.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Double]]
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
index 8dd29061da..60585d2727 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/tree/JavaDecisionTreeSuite.java
@@ -28,6 +28,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
@@ -95,6 +97,14 @@ public class JavaDecisionTreeSuite implements Serializable {
DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), strategy);
+ // java compatibility test
+ JavaRDD<Double> predictions = model.predict(rdd.map(new Function<LabeledPoint, Vector>() {
+ @Override
+ public Vector call(LabeledPoint v1) {
+ return v1.features();
+ }
+ }));
+
int numCorrect = validatePrediction(arr, model);
Assert.assertTrue(numCorrect == rdd.count());
}