diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-07-07 08:58:08 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-07 08:58:08 -0700 |
commit | 1dbc4a155f3697a3973909806be42a1be6017d12 (patch) | |
tree | aef8a830fbeca56fcfa2ca5b050959f62b2c5c25 /python/pyspark/ml/classification.py | |
parent | 0a63d7ab8a58d3e48d01740729a7832f1834efe8 (diff) | |
download | spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.gz spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.bz2 spark-1dbc4a155f3697a3973909806be42a1be6017d12.zip |
[SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Add numNodes and depth to tree models, and add treeWeights to tree ensemble models.
Add __repr__ to all models.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits:
23b08be [MechCoder] private [spark]
38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights
6d16ad8 [MechCoder] Fix Python 3 Error
47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods
819098c [MechCoder] [SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r-- | python/pyspark/ml/classification.py | 20 |
1 files changed, 16 insertions, 4 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7abbde8b26..89117e4928 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -18,7 +18,8 @@ from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * -from pyspark.ml.regression import RandomForestParams +from pyspark.ml.regression import ( + RandomForestParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc @@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> td = si_model.transform(df) >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") >>> model = dt.fit(td) + >>> model.numNodes + 3 + >>> model.depth + 1 >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -269,7 +274,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred return self.getOrDefault(self.impurity) -class DecisionTreeClassificationModel(JavaModel): +@inherit_doc +class DecisionTreeClassificationModel(DecisionTreeModel): """ Model fitted by DecisionTreeClassifier. """ @@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred It supports both binary and multiclass labels, as well as both continuous and categorical features. 
+ >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ @@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) + >>> allclose(model.treeWeights, [1.0, 1.0]) + True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -423,7 +432,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred return self.getOrDefault(self.featureSubsetStrategy) -class RandomForestClassificationModel(JavaModel): +class RandomForestClassificationModel(TreeEnsembleModels): """ Model fitted by RandomForestClassifier. """ @@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol It supports binary labels, as well as both continuous and categorical features. Note: Multiclass labels are not currently supported. + >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ @@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed") >>> model = gbt.fit(td) + >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) + True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -558,7 +570,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol return self.getOrDefault(self.stepSize) -class GBTClassificationModel(JavaModel): +class GBTClassificationModel(TreeEnsembleModels): """ Model fitted by GBTClassifier. """ |