aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Davies Liu <davies@databricks.com> 2014-11-12 13:56:41 -0800
committer Xiangrui Meng <meng@databricks.com> 2014-11-12 13:56:41 -0800
commit bd86118c4e980f94916f892c76fb808fd4c8bd85 (patch)
tree 46646849b27d832fdf08c05de84c34b59f8916a7 /mllib
parent a5ef58113667ff73562ce6db381cff96a0b354b0 (diff)
download spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.tar.gz
spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.tar.bz2
spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.zip
[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD
Fix TreeModel.predict() with an RDD and add tests for it. (Also checked that the other models don't have this issue.) Author: Davies Liu <davies@databricks.com> Closes #3230 from davies/predict and squashes the following commits: 81172aa [Davies Liu] fix predict
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala 12
1 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ec1d99ab26..ac4d02ee39 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
@@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
features.map(x => predict(x))
}
+
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features JavaRDD representing data points to be predicted
+ * @return JavaRDD of predictions for each of the given data points
+ */
+ def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
+ JavaRDD.fromRDD(predict(features.rdd)) // unwrap, reuse the RDD overload, then wrap explicitly rather than via the implicit view
+ }
+
/**
* Get number of nodes in tree, including leaf nodes.
*/