aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-07-29 18:18:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-29 18:18:29 -0700
commit37c2d1927cebdd19a14c054f670cb0fb9a263586 (patch)
treeab1d400f9795dc429c511083005d02ffca4ea7da /python
parent103d8cce78533b38b4f8060b30f7f455113bc6b5 (diff)
downloadspark-37c2d1927cebdd19a14c054f670cb0fb9a263586.tar.gz
spark-37c2d1927cebdd19a14c054f670cb0fb9a263586.tar.bz2
spark-37c2d1927cebdd19a14c054f670cb0fb9a263586.zip
[SPARK-9016] [ML] make random forest classifiers implement classification trait
Implement the classification trait for RandomForestClassifiers. The plan is to use this in the future to providing thresholding for RandomForestClassifiers (as well as other classifiers that implement that trait). Author: Holden Karau <holden@pigscanfly.ca> Closes #7432 from holdenk/SPARK-9016-make-random-forest-classifiers-implement-classification-trait and squashes the following commits: bf22fa6 [Holden Karau] Add missing imports for testing suite e948f0d [Holden Karau] Check the prediction generation from rawprediciton 25320c3 [Holden Karau] Don't supply numClasses when not needed, assert model classes are as expected 1a67e04 [Holden Karau] Use old decission tree stuff instead 673e0c3 [Holden Karau] Merge branch 'master' into SPARK-9016-make-random-forest-classifiers-implement-classification-trait 0d15b96 [Holden Karau] FIx typo 5eafad4 [Holden Karau] add a constructor for rootnode + num classes fc6156f [Holden Karau] scala style fix 2597915 [Holden Karau] take num classes in constructor 3ccfe4a [Holden Karau] Merge in master, make pass numClasses through randomforest for training 222a10b [Holden Karau] Increase numtrees to 3 in the python test since before the two were equal and the argmax was selecting the last one 16aea1c [Holden Karau] Make tests match the new models b454a02 [Holden Karau] Make the Tree classifiers extends the Classifier base class 77b4114 [Holden Karau] Import vectors lib
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 89117e4928..5a82bc286d 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -299,9 +299,9 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
- >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
+ >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
- >>> allclose(model.treeWeights, [1.0, 1.0])
+ >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction