diff options
author | Davies Liu <davies@databricks.com> | 2014-11-12 13:56:41 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-12 13:56:41 -0800 |
commit | bd86118c4e980f94916f892c76fb808fd4c8bd85 (patch) | |
tree | 46646849b27d832fdf08c05de84c34b59f8916a7 /python | |
parent | a5ef58113667ff73562ce6db381cff96a0b354b0 (diff) | |
download | spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.tar.gz spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.tar.bz2 spark-bd86118c4e980f94916f892c76fb808fd4c8bd85.zip |
[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD
Fix TreeModel.predict() when it is called with an RDD, and add tests for it.
(Also checked that the other models don't have this issue.)
Author: Davies Liu <davies@databricks.com>
Closes #3230 from davies/predict and squashes the following commits:
81172aa [Davies Liu] fix predict
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/tree.py | 26 |
1 file changed, 14 insertions, 12 deletions
```diff
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0962..ef0d556fac 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -124,10 +124,13 @@ class DecisionTree(object):
            Predict: 0.0
           Else (feature 0 > 0.0)
            Predict: 1.0
-        >>> model.predict(array([1.0])) > 0
-        True
-        >>> model.predict(array([0.0])) == 0
-        True
+        >>> model.predict(array([1.0]))
+        1.0
+        >>> model.predict(array([0.0]))
+        0.0
+        >>> rdd = sc.parallelize([[1.0], [0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ class DecisionTree(object):
         ... ]
         >>>
         >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
-        >>> model.predict(array([0.0, 1.0])) == 1
-        True
-        >>> model.predict(array([0.0, 0.0])) == 0
-        True
-        >>> model.predict(SparseVector(2, {1: 1.0})) == 1
-        True
-        >>> model.predict(SparseVector(2, {1: 0.0})) == 0
-        True
+        >>> model.predict(SparseVector(2, {1: 1.0}))
+        1.0
+        >>> model.predict(SparseVector(2, {1: 0.0}))
+        0.0
+        >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
```