diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/mllib/tree.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index e1a4671709..e9d778df5a 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -88,7 +88,8 @@ class DecisionTree(object): It will probably be modified for Spark v1.2. Example usage: - >>> from numpy import array, ndarray + >>> from numpy import array + >>> import sys >>> from pyspark.mllib.regression import LabeledPoint >>> from pyspark.mllib.tree import DecisionTree >>> from pyspark.mllib.linalg import SparseVector @@ -99,15 +100,15 @@ class DecisionTree(object): ... LabeledPoint(1.0, [2.0]), ... LabeledPoint(1.0, [3.0]) ... ] - >>> - >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2) - >>> print(model) + >>> categoricalFeaturesInfo = {} # no categorical features + >>> model = DecisionTree.trainClassifier(sc.parallelize(data), numClasses=2, + ... categoricalFeaturesInfo=categoricalFeaturesInfo) + >>> sys.stdout.write(model) DecisionTreeModel classifier If (feature 0 <= 0.5) Predict: 0.0 Else (feature 0 > 0.5) Predict: 1.0 - >>> model.predict(array([1.0])) > 0 True >>> model.predict(array([0.0])) == 0 @@ -119,7 +120,8 @@ class DecisionTree(object): ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) ... ] >>> - >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data)) + >>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), + ... categoricalFeaturesInfo=categoricalFeaturesInfo) >>> model.predict(array([0.0, 1.0])) == 1 True >>> model.predict(array([0.0, 0.0])) == 0 |