diff options
author | Davies Liu <davies@databricks.com> | 2014-11-12 13:56:41 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-12 13:56:41 -0800 |
commit | bd86118c4e980f94916f892c76fb808fd4c8bd85 (patch) | |
tree | 46646849b27d832fdf08c05de84c34b59f8916a7 /mllib/src/main | |
parent | a5ef58113667ff73562ce6db381cff96a0b354b0 (diff) | |
download | spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.tar.gz spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.tar.bz2 spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.zip |
[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD
Fix TreeModel.predict() with RDD, added tests for it.
(Also checked that other models don't have this issue)
Author: Davies Liu <davies@databricks.com>
Closes #3230 from davies/predict and squashes the following commits:
81172aa [Davies Liu] fix predict
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala index ec1d99ab26..ac4d02ee39 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.tree.model +import org.apache.spark.api.java.JavaRDD import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.rdd.RDD @@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable features.map(x => predict(x)) } + + /** + * Predict values for the given data set using the model trained. + * + * @param features JavaRDD representing data points to be predicted + * @return JavaRDD of predictions for each of the given data points + */ + def predict(features: JavaRDD[Vector]): JavaRDD[Double] = { + predict(features.rdd) + } + /** * Get number of nodes in tree, including leaf nodes. */ |