From bd86118c4e980f94916f892c76fb808fd4c8bd85 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 12 Nov 2014 13:56:41 -0800 Subject: [SPARK-4369] [MLLib] fix TreeModel.predict() with RDD Fix TreeModel.predict() with RDD, added tests for it. (Also checked that other models don't have this issue) Author: Davies Liu Closes #3230 from davies/predict and squashes the following commits: 81172aa [Davies Liu] fix predict --- python/pyspark/mllib/tree.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) (limited to 'python/pyspark/mllib/tree.py') diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 5d1a3c0962..ef0d556fac 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -124,10 +124,13 @@ class DecisionTree(object): Predict: 0.0 Else (feature 0 > 0.0) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 - True - >>> model.predict(array([0.0])) == 0 - True + >>> model.predict(array([1.0])) + 1.0 + >>> model.predict(array([0.0])) + 0.0 + >>> rdd = sc.parallelize([[1.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) @@ -170,14 +173,13 @@ class DecisionTree(object): ... ] >>> >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {}) - >>> model.predict(array([0.0, 1.0])) == 1 - True - >>> model.predict(array([0.0, 0.0])) == 0 - True - >>> model.predict(SparseVector(2, {1: 1.0})) == 1 - True - >>> model.predict(SparseVector(2, {1: 0.0})) == 0 - True + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {1: 0.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] """ return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -- cgit v1.2.3