Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--  python/pyspark/ml/classification.py  44
1 file changed, 44 insertions, 0 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 29d1d203f2..ec8834a89e 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -285,6 +285,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
3
>>> model.depth
1
+ >>> model.featureImportances
+ SparseVector(1, {0: 1.0})
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> result = model.transform(test0).head()
>>> result.prediction
@@ -352,6 +354,27 @@ class DecisionTreeClassificationModel(DecisionTreeModel):
.. versionadded:: 1.4.0
"""
+ @property
+ @since("2.0.0")
+ def featureImportances(self):
+ """
+ Estimate of the importance of each feature.
+
+ This generalizes the idea of "Gini" importance to other losses,
+ following the explanation of Gini importance from "Random Forests" documentation
+ by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+
+ This feature importance is calculated as follows:
+ - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ where the gain is scaled by the number of instances passing through the node
+ - Normalize importances for the tree to sum to 1.
+
+ Note: Feature importance for single decision trees can have high variance due to
+ correlated predictor variables. Consider using a :class:`RandomForestClassifier`
+ to determine feature importance instead.
+ """
+ return self._call_java("featureImportances")
+
@inherit_doc
class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasSeed,
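A minimal sketch (not Spark's internal implementation) of the per-tree calculation described in the featureImportances docstring above: each split node contributes its impurity gain, weighted by the number of training instances reaching that node, to the feature it splits on, and the resulting vector is normalized to sum to 1. The (feature_index, gain, num_instances) node representation used here is purely hypothetical.

    from collections import defaultdict

    def tree_feature_importances(split_nodes):
        # split_nodes: hypothetical iterable of (feature_index, gain, num_instances)
        # tuples, one per internal split node of a single decision tree.
        importances = defaultdict(float)
        for feature, gain, num_instances in split_nodes:
            # gain scaled by the number of instances passing through the node
            importances[feature] += gain * num_instances
        total = sum(importances.values())
        # normalize so the importances for the tree sum to 1
        if total > 0:
            return {f: v / total for f, v in importances.items()}
        return dict(importances)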
@@ -375,6 +398,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
+ >>> model.featureImportances
+ SparseVector(1, {0: 1.0})
>>> allclose(model.treeWeights, [1.0, 1.0, 1.0])
True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
@@ -443,6 +468,25 @@ class RandomForestClassificationModel(TreeEnsembleModels):
.. versionadded:: 1.4.0
"""
+ @property
+ @since("2.0.0")
+ def featureImportances(self):
+ """
+ Estimate of the importance of each feature.
+
+ This generalizes the idea of "Gini" importance to other losses,
+ following the explanation of Gini importance from "Random Forests" documentation
+ by Leo Breiman and Adele Cutler, and following the implementation from scikit-learn.
+
+ This feature importance is calculated as follows:
+ - Average over trees:
+ - importance(feature j) = sum (over nodes which split on feature j) of the gain,
+ where the gain is scaled by the number of instances passing through the node
+ - Normalize importances for each tree to sum to 1.
+ - Normalize feature importance vector to sum to 1.
+ """
+ return self._call_java("featureImportances")
+
@inherit_doc
class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
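Similarly, a minimal sketch (under the same hypothetical node representation, again not Spark's implementation) of the ensemble calculation described in the RandomForestClassificationModel.featureImportances docstring: normalized per-tree importances are averaged across the trees, and the averaged vector is normalized to sum to 1. It reuses the tree_feature_importances() sketch shown earlier.

    from collections import defaultdict

    def forest_feature_importances(trees):
        # trees: hypothetical list of per-tree split-node lists, in the form
        # accepted by the tree_feature_importances() sketch above.
        totals = defaultdict(float)
        for split_nodes in trees:
            for feature, value in tree_feature_importances(split_nodes).items():
                totals[feature] += value / len(trees)  # average over trees
        norm = sum(totals.values())
        # normalize the feature importance vector to sum to 1
        if norm > 0:
            return {f: v / norm for f, v in totals.items()}
        return dict(totals)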