aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-12 13:56:41 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-12 13:56:50 -0800
commit16da988c5cdae935151e307a66a5385bac5167c3 (patch)
tree1acbebd98154901f1009f748f2ede169775660c8 /python
parent127c19b449315bdeba758e48371291c61abf0952 (diff)
downloadspark-16da988c5cdae935151e307a66a5385bac5167c3.tar.gz
spark-16da988c5cdae935151e307a66a5385bac5167c3.tar.bz2
spark-16da988c5cdae935151e307a66a5385bac5167c3.zip
[SPARK-4369] [MLLib] fix TreeModel.predict() with RDD
Fix TreeModel.predict() with RDD, added tests for it. (Also checked that other models don't have this issue) Author: Davies Liu <davies@databricks.com> Closes #3230 from davies/predict and squashes the following commits: 81172aa [Davies Liu] fix predict (cherry picked from commit bd86118c4e980f94916f892c76fb808fd4c8bd85) Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tree.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0962..ef0d556fac 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -124,10 +124,13 @@ class DecisionTree(object):
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
- >>> model.predict(array([1.0])) > 0
- True
- >>> model.predict(array([0.0])) == 0
- True
+ >>> model.predict(array([1.0]))
+ 1.0
+ >>> model.predict(array([0.0]))
+ 0.0
+ >>> rdd = sc.parallelize([[1.0], [0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
"""
return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ class DecisionTree(object):
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
- >>> model.predict(array([0.0, 1.0])) == 1
- True
- >>> model.predict(array([0.0, 0.0])) == 0
- True
- >>> model.predict(SparseVector(2, {1: 1.0})) == 1
- True
- >>> model.predict(SparseVector(2, {1: 0.0})) == 0
- True
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {1: 0.0}))
+ 0.0
+ >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
"""
return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)