From e375456063617cd7000d796024f41e5927f21edd Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Tue, 4 Aug 2015 14:54:26 -0700
Subject: [SPARK-9447] [ML] [PYTHON] Added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier

Added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier, plus
doc tests for those columns.

CC: holdenk yanboliang

Author: Joseph K. Bradley

Closes #7903 from jkbradley/rf-prob-python and squashes the following commits:

c62a83f [Joseph K. Bradley] made unit test more robust
14eeba2 [Joseph K. Bradley] added HasRawPredictionCol, HasProbabilityCol to RandomForestClassifier in PySpark
---
 python/pyspark/ml/classification.py | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

(limited to 'python/pyspark/ml/classification.py')

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 291320f881..5978d8f4d3 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -347,6 +347,7 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
 
 @inherit_doc
 class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
+                             HasRawPredictionCol, HasProbabilityCol,
                              DecisionTreeParams, HasCheckpointInterval):
     """
     `http://en.wikipedia.org/wiki/Random_forest  Random Forest`
@@ -354,6 +355,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     It supports both binary and multiclass labels, as well as both continuous and categorical
     features.
 
+    >>> import numpy
     >>> from numpy import allclose
     >>> from pyspark.mllib.linalg import Vectors
     >>> from pyspark.ml.feature import StringIndexer
@@ -368,8 +370,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
-    >>> model.transform(test0).head().prediction
+    >>> result = model.transform(test0).head()
+    >>> result.prediction
     0.0
+    >>> numpy.argmax(result.probability)
+    0
+    >>> numpy.argmax(result.rawPrediction)
+    0
     >>> test1 = sqlContext.createDataFrame([(Vectors.sparse(1, [0], [1.0]),)], ["features"])
     >>> model.transform(test1).head().prediction
     1.0
@@ -390,11 +397,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                 probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
                  numTrees=20, featureSubsetStrategy="auto", seed=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                 probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
                  numTrees=20, featureSubsetStrategy="auto", seed=None)
@@ -427,11 +436,13 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
 
     @keyword_only
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
                   impurity="gini", numTrees=20, featureSubsetStrategy="auto")
-- 
cgit v1.2.3
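A minimal standalone sketch of the behavior the new doctest lines check: the probability and rawPrediction columns now appear in the transformed output. This assumes a Spark 1.5-era sqlContext is already in scope (as in the doctest) and reuses the same toy data; it is a usage illustration, not part of the patch.

# Sketch: inspect the new probability/rawPrediction output columns of
# RandomForestClassificationModel.transform(). Assumes `sqlContext` exists.
import numpy
from pyspark.mllib.linalg import Vectors
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StringIndexer

# Same toy training data as the doctest: one positive and one negative example.
df = sqlContext.createDataFrame(
    [(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))],
    ["label", "features"])
si_model = StringIndexer(inputCol="label", outputCol="indexed").fit(df)
td = si_model.transform(df)

rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
model = rf.fit(td)

test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
result = model.transform(test0).head()
print(result.prediction)                   # 0.0
print(numpy.argmax(result.probability))    # index of the most probable class (0)
print(numpy.argmax(result.rawPrediction))  # index of the largest raw vote count (0)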