aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/ml/classification.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 291320f881..5978d8f4d3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -347,6 +347,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
+ HasRawPredictionCol, HasProbabilityCol,
DecisionTreeParams, HasCheckpointInterval):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
@@ -354,6 +355,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
+ >>> import numpy
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
@@ -368,8 +370,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
- >>> model.transform(test0).head().prediction
+ >>> result = model.transform(test0).head()
+ >>> result.prediction
0.0
+ >>> numpy.argmax(result.probability)
+ 0
+ >>> numpy.argmax(result.rawPrediction)
+ 0
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
@@ -390,11 +397,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
numTrees=20, featureSubsetStrategy="auto", seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None)
@@ -427,11 +436,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto")