diff options
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 44 |
1 file changed, 44 insertions, 0 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 29d1d203f2..ec8834a89e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -285,6 +285,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     3
     >>> model.depth
     1
+    >>> model.featureImportances
+    SparseVector(1, {0: 1.0})
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
     >>> result = model.transform(test0).head()
     >>> result.prediction
@@ -352,6 +354,27 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
     .. versionadded:: 1.4.0
     """

+    @property
+    @since("2.0.0")
+    def featureImportances(self):
+        """
+        Estimate of the importance of each feature.
+
+        This generalizes the idea of "Gini" importance to other losses,
+        following the explanation of Gini importance from "Random Forests" documentation
+        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+
+        This feature importance is calculated as follows:
+          - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+            where gain is scaled by the number of instances passing through node
+          - Normalize importances for tree to sum to 1.
+
+        Note: Feature importance for single decision trees can have high variance due to
+              correlated predictor variables. Consider using a :class:`RandomForestClassifier`
+              to determine feature importance instead.
+        """
+        return self._call_java("featureImportances")
+

 @inherit_doc
 class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
@@ -375,6 +398,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
     >>> td = si_model.transform(df)
     >>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
     >>> model = rf.fit(td)
+    >>> model.featureImportances
+    SparseVector(1, {0: 1.0})
     >>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
     True
     >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -443,6 +468,25 @@ class RandomForestClassificationModel(TreeEnsembleModels):
     .. versionadded:: 1.4.0
     """

+    @property
+    @since("2.0.0")
+    def featureImportances(self):
+        """
+        Estimate of the importance of each feature.
+
+        This generalizes the idea of "Gini" importance to other losses,
+        following the explanation of Gini importance from "Random Forests" documentation
+        by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+
+        This feature importance is calculated as follows:
+         - Average over trees:
+           - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+             where gain is scaled by the number of instances passing through node
+           - Normalize importances for tree to sum to 1.
+         - Normalize feature importance vector to sum to 1.
+        """
+        return self._call_java("featureImportances")
+

 @inherit_doc
 class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,