aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tree.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tree.py')
-rw-r--r--python/pyspark/mllib/tree.py15
1 files changed, 8 insertions, 7 deletions
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 0fe6e4fabe..cfcbea573f 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -482,13 +482,13 @@ class GradientBoostedTrees(object):
... LabeledPoint(1.0, [3.0])
... ]
>>>
- >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {})
+ >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}, numIterations=10)
>>> model.numTrees()
- 100
+ 10
>>> model.totalNumNodes()
- 300
+ 30
>>> print(model) # it already has newline
- TreeEnsembleModel classifier with 100 trees
+ TreeEnsembleModel classifier with 10 trees
<BLANKLINE>
>>> model.predict([2.0])
1.0
@@ -541,11 +541,12 @@ class GradientBoostedTrees(object):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>>
- >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {})
+ >>> data = sc.parallelize(sparse_data)
+ >>> model = GradientBoostedTrees.trainRegressor(data, {}, numIterations=10)
>>> model.numTrees()
- 100
+ 10
>>> model.totalNumNodes()
- 102
+ 12
>>> model.predict(SparseVector(2, {1: 1.0}))
1.0
>>> model.predict(SparseVector(2, {0: 1.0}))