diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-07-07 08:58:08 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-07 08:58:08 -0700 |
commit | 1dbc4a155f3697a3973909806be42a1be6017d12 (patch) | |
tree | aef8a830fbeca56fcfa2ca5b050959f62b2c5c25 /python | |
parent | 0a63d7ab8a58d3e48d01740729a7832f1834efe8 (diff) | |
download | spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.gz spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.bz2 spark-1dbc4a155f3697a3973909806be42a1be6017d12.zip |
[SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Add numNodes and depth to the decision tree models, and add treeWeights to the tree ensemble models.
Add __repr__ to all of these models.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits:
23b08be [MechCoder] private [spark]
38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights
6d16ad8 [MechCoder] Fix Python 3 Error
47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods
819098c [MechCoder] [SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/classification.py | 20 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 46 |
2 files changed, 59 insertions, 7 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 7abbde8b26..89117e4928 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -18,7 +18,8 @@ from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel from pyspark.ml.param.shared import * -from pyspark.ml.regression import RandomForestParams +from pyspark.ml.regression import ( + RandomForestParams, DecisionTreeModel, TreeEnsembleModels) from pyspark.mllib.common import inherit_doc @@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> td = si_model.transform(df) >>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed") >>> model = dt.fit(td) + >>> model.numNodes + 3 + >>> model.depth + 1 >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -269,7 +274,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred return self.getOrDefault(self.impurity) -class DecisionTreeClassificationModel(JavaModel): +@inherit_doc +class DecisionTreeClassificationModel(DecisionTreeModel): """ Model fitted by DecisionTreeClassifier. """ @@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred It supports both binary and multiclass labels, as well as both continuous and categorical features. 
+ >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ @@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred >>> td = si_model.transform(df) >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42) >>> model = rf.fit(td) + >>> allclose(model.treeWeights, [1.0, 1.0]) + True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -423,7 +432,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred return self.getOrDefault(self.featureSubsetStrategy) -class RandomForestClassificationModel(JavaModel): +class RandomForestClassificationModel(TreeEnsembleModels): """ Model fitted by RandomForestClassifier. """ @@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol It supports binary labels, as well as both continuous and categorical features. Note: Multiclass labels are not currently supported. + >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> from pyspark.ml.feature import StringIndexer >>> df = sqlContext.createDataFrame([ @@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol >>> td = si_model.transform(df) >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed") >>> model = gbt.fit(td) + >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) + True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -558,7 +570,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol return self.getOrDefault(self.stepSize) -class GBTClassificationModel(JavaModel): +class GBTClassificationModel(TreeEnsembleModels): """ Model fitted by GBTClassifier. 
""" diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b139e27372..44f60a7695 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> dt = DecisionTreeRegressor(maxDepth=2) >>> model = dt.fit(df) + >>> model.depth + 1 + >>> model.numNodes + 3 >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -239,7 +243,37 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return self.getOrDefault(self.impurity) -class DecisionTreeRegressionModel(JavaModel): +@inherit_doc +class DecisionTreeModel(JavaModel): + + @property + def numNodes(self): + """Return number of nodes of the decision tree.""" + return self._call_java("numNodes") + + @property + def depth(self): + """Return depth of the decision tree.""" + return self._call_java("depth") + + def __repr__(self): + return self._call_java("toString") + + +@inherit_doc +class TreeEnsembleModels(JavaModel): + + @property + def treeWeights(self): + """Return the weights for each tree""" + return list(self._call_java("javaTreeWeights")) + + def __repr__(self): + return self._call_java("toString") + + +@inherit_doc +class DecisionTreeRegressionModel(DecisionTreeModel): """ Model fitted by DecisionTreeRegressor. """ @@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi learning algorithm for regression. It supports both continuous and categorical features. + >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... 
(0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42) >>> model = rf.fit(df) + >>> allclose(model.treeWeights, [1.0, 1.0]) + True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -389,7 +426,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return self.getOrDefault(self.featureSubsetStrategy) -class RandomForestRegressionModel(JavaModel): +class RandomForestRegressionModel(TreeEnsembleModels): """ Model fitted by RandomForestRegressor. """ @@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, learning algorithm for regression. It supports both continuous and categorical features. + >>> from numpy import allclose >>> from pyspark.mllib.linalg import Vectors >>> df = sqlContext.createDataFrame([ ... (1.0, Vectors.dense(1.0)), ... (0.0, Vectors.sparse(1, [], []))], ["label", "features"]) >>> gbt = GBTRegressor(maxIter=5, maxDepth=2) >>> model = gbt.fit(df) + >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1]) + True >>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"]) >>> model.transform(test0).head().prediction 0.0 @@ -518,7 +558,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, return self.getOrDefault(self.stepSize) -class GBTRegressionModel(JavaModel): +class GBTRegressionModel(TreeEnsembleModels): """ Model fitted by GBTRegressor. """ |