author    Joseph K. Bradley <joseph@databricks.com>    2015-08-04 14:54:26 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-08-04 14:54:26 -0700
commit    e375456063617cd7000d796024f41e5927f21edd (patch)
tree      dd910b1f89023728d652b9a962b2c21beafe6af6 /python/pyspark/ml/classification.py
parent    9d668b73687e697cad2ef7fd3c3ba405e9795593 (diff)
[SPARK-9447] [ML] [PYTHON] Added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier
Added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier, plus doc tests for those columns.

CC: holdenk yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7903 from jkbradley/rf-prob-python and squashes the following commits:

c62a83f [Joseph K. Bradley] made unit test more robust
14eeba2 [Joseph K. Bradley] added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier in PySpark
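
For orientation, here is a minimal sketch of how the two new output columns are consumed after this change. It assumes a running SQLContext named `sqlContext` and the Spark 1.5-era pyspark.ml API; the toy data, seed, and column names are illustrative and mirror the doctest in the diff below rather than reproducing it exactly:

import numpy
from pyspark.mllib.linalg import Vectors
from pyspark.ml.feature import StringIndexer
from pyspark.ml.classification import RandomForestClassifier

# Toy training data with a string-indexed label column.
df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])
td = StringIndexer(inputCol="label", outputCol="indexed").fit(df).transform(df)

rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
model = rf.fit(td)

# Columns exposed via HasProbabilityCol / HasRawPredictionCol:
row = model.transform(
    sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])).head()
print(row.prediction)                   # predicted label, e.g. 0.0
print(numpy.argmax(row.probability))    # index of the most probable class
print(numpy.argmax(row.rawPrediction))  # argmax of the raw per-class scores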
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- python/pyspark/ml/classification.py | 13
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 291320f881..5978d8f4d3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -347,6 +347,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
+ HasRawPredictionCol, HasProbabilityCol,
DecisionTreeParams, HasCheckpointInterval):
"""
`http://en.wikipedia.org/wiki/Random_forest Random Forest`
@@ -354,6 +355,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
+ >>> import numpy
>>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
@@ -368,8 +370,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
- >>> model.transform(test0).head().prediction
+ >>> result = model.transform(test0).head()
+ >>> result.prediction
0.0
+ >>> numpy.argmax(result.probability)
+ 0
+ >>> numpy.argmax(result.rawPrediction)
+ 0
>>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
>>> model.transform(test1).head().prediction
1.0
@@ -390,11 +397,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
numTrees=20, featureSubsetStrategy="auto", seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
numTrees=20, featureSubsetStrategy="auto", seed=None)
@@ -427,11 +436,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto")