aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala12
-rw-r--r--python/pyspark/mllib/tree.py26
2 files changed, 26 insertions, 12 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index ec1d99ab26..ac4d02ee39 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.tree.model
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
@@ -52,6 +53,17 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
features.map(x => predict(x))
}
+
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param features JavaRDD representing data points to be predicted
+ * @return JavaRDD of predictions for each of the given data points
+ */
+ def predict(features: JavaRDD[Vector]): JavaRDD[Double] = {
+ predict(features.rdd)
+ }
+
/**
* Get number of nodes in tree, including leaf nodes.
*/
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0962..ef0d556fac 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -124,10 +124,13 @@ class DecisionTree(object):
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
- >>> model.predict(array([1.0])) > 0
- True
- >>> model.predict(array([0.0])) == 0
- True
+ >>> model.predict(array([1.0]))
+ 1.0
+ >>> model.predict(array([0.0]))
+ 0.0
+ >>> rdd = sc.parallelize([[1.0], [0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
"""
return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ class DecisionTree(object):
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
- >>> model.predict(array([0.0, 1.0])) == 1
- True
- >>> model.predict(array([0.0, 0.0])) == 0
- True
- >>> model.predict(SparseVector(2, {1: 1.0})) == 1
- True
- >>> model.predict(SparseVector(2, {1: 0.0})) == 0
- True
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {1: 0.0}))
+ 0.0
+ >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
"""
return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)