aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-07 08:58:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-07 08:58:08 -0700
commit1dbc4a155f3697a3973909806be42a1be6017d12 (patch)
treeaef8a830fbeca56fcfa2ca5b050959f62b2c5c25 /python/pyspark/ml/classification.py
parent0a63d7ab8a58d3e48d01740729a7832f1834efe8 (diff)
downloadspark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.gz
spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.bz2
spark-1dbc4a155f3697a3973909806be42a1be6017d12.zip
[SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Add numNodes and depth to treeModels, add treeWeights to ensemble Models. Add __repr__ to all models. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits: 23b08be [MechCoder] private [spark] 38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights 6d16ad8 [MechCoder] Fix Python 3 Error 47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods 819098c [MechCoder] [SPARK-8711] [ML] Add additional methods ot PySpark ML tree models
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py20
1 files changed, 16 insertions, 4 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 7abbde8b26..89117e4928 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -18,7 +18,8 @@
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import *
-from pyspark.ml.regression import RandomForestParams
+from pyspark.ml.regression import (
+ RandomForestParams, DecisionTreeModel, TreeEnsembleModels)
from pyspark.mllib.common import inherit_doc
@@ -202,6 +203,10 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> dt = DecisionTreeClassifier(maxDepth=2, labelCol="indexed")
>>> model = dt.fit(td)
+ >>> model.numNodes
+ 3
+ >>> model.depth
+ 1
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -269,7 +274,8 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return self.getOrDefault(self.impurity)
-class DecisionTreeClassificationModel(JavaModel):
+@inherit_doc
+class DecisionTreeClassificationModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeClassifier.
"""
@@ -284,6 +290,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
It supports both binary and multiclass labels, as well as both continuous and categorical
features.
+ >>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
@@ -294,6 +301,8 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> td = si_model.transform(df)
>>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
+ >>> allclose(model.treeWeights, [1.0, 1.0])
+ True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -423,7 +432,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
return self.getOrDefault(self.featureSubsetStrategy)
-class RandomForestClassificationModel(JavaModel):
+class RandomForestClassificationModel(TreeEnsembleModels):
"""
Model fitted by RandomForestClassifier.
"""
@@ -438,6 +447,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
It supports binary labels, as well as both continuous and categorical features.
Note: Multiclass labels are not currently supported.
+ >>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> from pyspark.ml.feature import StringIndexer
>>> df = sqlContext.createDataFrame([
@@ -448,6 +458,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> td = si_model.transform(df)
>>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
>>> model = gbt.fit(td)
+ >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
+ True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -558,7 +570,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
return self.getOrDefault(self.stepSize)
-class GBTClassificationModel(JavaModel):
+class GBTClassificationModel(TreeEnsembleModels):
"""
Model fitted by GBTClassifier.
"""