diff options
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/tree.py | 26 |
1 files changed, 14 insertions, 12 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0962..ef0d556fac 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -124,10 +124,13 @@ class DecisionTree(object):
            Predict: 0.0
           Else (feature 0 > 0.0)
            Predict: 1.0
-        >>> model.predict(array([1.0])) > 0
-        True
-        >>> model.predict(array([0.0])) == 0
-        True
+        >>> model.predict(array([1.0]))
+        1.0
+        >>> model.predict(array([0.0]))
+        0.0
+        >>> rdd = sc.parallelize([[1.0], [0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
@@ -170,14 +173,13 @@ class DecisionTree(object):
         ... ]
         >>>
         >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
-        >>> model.predict(array([0.0, 1.0])) == 1
-        True
-        >>> model.predict(array([0.0, 0.0])) == 0
-        True
-        >>> model.predict(SparseVector(2, {1: 1.0})) == 1
-        True
-        >>> model.predict(SparseVector(2, {1: 0.0})) == 0
-        True
+        >>> model.predict(SparseVector(2, {1: 1.0}))
+        1.0
+        >>> model.predict(SparseVector(2, {1: 0.0}))
+        0.0
+        >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+        >>> model.predict(rdd).collect()
+        [1.0, 0.0]
         """
         return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
                                    impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)