aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/tree.py14
-rwxr-xr-xpython/run-tests1
2 files changed, 9 insertions, 6 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index e1a4671709..e9d778df5a 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -88,7 +88,8 @@ class DecisionTree(object):
It will probably be modified for Spark v1.2.
Example usage:
- >>> from numpy import array, ndarray
+ >>> from numpy import array
+ >>> import sys
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
@@ -99,15 +100,15 @@ class DecisionTree(object):
... LabeledPoint(1.0, [2.0]),
... LabeledPoint(1.0, [3.0])
... ]
- >>>
- >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2)
- >>> print(model)
+ >>> categoricalFeaturesInfo = {} # no categorical features
+ >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2,
+ ... categoricalFeaturesInfo=categoricalFeaturesInfo)
+ >>> sys.stdout.write(model)
DecisionTreeModel classifier
If (feature 0 <= 0.5)
Predict: 0.0
Else (feature 0 > 0.5)
Predict: 1.0
-
>>> model.predict(array([1.0])) > 0
True
>>> model.predict(array([0.0])) == 0
@@ -119,7 +120,8 @@ class DecisionTree(object):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
- >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data))
+ >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data),
+ ... categoricalFeaturesInfo=categoricalFeaturesInfo)
>>> model.predict(array([0.0, 1.0])) == 1
True
>>> model.predict(array([0.0, 0.0])) == 0
diff --git a/python/run-tests b/python/run-tests
index 1218edcbd7..a6271e0cf5 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -79,6 +79,7 @@ run_test "pyspark/mllib/random.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/tests.py"
+run_test "pyspark/mllib/tree.py"
run_test "pyspark/mllib/util.py"
if [[ $FAILED == 0 ]]; then