aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-07 08:58:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-07 08:58:08 -0700
commit1dbc4a155f3697a3973909806be42a1be6017d12 (patch)
treeaef8a830fbeca56fcfa2ca5b050959f62b2c5c25 /python/pyspark/ml/regression.py
parent0a63d7ab8a58d3e48d01740729a7832f1834efe8 (diff)
downloadspark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.gz
spark-1dbc4a155f3697a3973909806be42a1be6017d12.tar.bz2
spark-1dbc4a155f3697a3973909806be42a1be6017d12.zip
[SPARK-8711] [ML] Add additional methods to PySpark ML tree models
Add numNodes and depth to treeModels, add treeWeights to ensemble Models. Add __repr__ to all models. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #7095 from MechCoder/missing_methods_tree and squashes the following commits: 23b08be [MechCoder] private [spark] 38a0860 [MechCoder] rename pyTreeWeights to javaTreeWeights 6d16ad8 [MechCoder] Fix Python 3 Error 47d7023 [MechCoder] Use np.allclose and treeEnsembleModel -> TreeEnsembleMethods 819098c [MechCoder] [SPARK-8711] [ML] Add additional methods ot PySpark ML tree models
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py46
1 files changed, 43 insertions, 3 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index b139e27372..44f60a7695 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -172,6 +172,10 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> dt = DecisionTreeRegressor(maxDepth=2)
>>> model = dt.fit(df)
+ >>> model.depth
+ 1
+ >>> model.numNodes
+ 3
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -239,7 +243,37 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.impurity)
-class DecisionTreeRegressionModel(JavaModel):
+@inherit_doc
+class DecisionTreeModel(JavaModel):
+
+ @property
+ def numNodes(self):
+ """Return number of nodes of the decision tree."""
+ return self._call_java("numNodes")
+
+ @property
+ def depth(self):
+ """Return depth of the decision tree."""
+ return self._call_java("depth")
+
+ def __repr__(self):
+ return self._call_java("toString")
+
+
+@inherit_doc
+class TreeEnsembleModels(JavaModel):
+
+ @property
+ def treeWeights(self):
+ """Return the weights for each tree"""
+ return list(self._call_java("javaTreeWeights"))
+
+ def __repr__(self):
+ return self._call_java("toString")
+
+
+@inherit_doc
+class DecisionTreeRegressionModel(DecisionTreeModel):
"""
Model fitted by DecisionTreeRegressor.
"""
@@ -253,12 +287,15 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
learning algorithm for regression.
It supports both continuous and categorical features.
+ >>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
+ >>> allclose(model.treeWeights, [1.0, 1.0])
+ True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -389,7 +426,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.featureSubsetStrategy)
-class RandomForestRegressionModel(JavaModel):
+class RandomForestRegressionModel(TreeEnsembleModels):
"""
Model fitted by RandomForestRegressor.
"""
@@ -403,12 +440,15 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
learning algorithm for regression.
It supports both continuous and categorical features.
+ >>> from numpy import allclose
>>> from pyspark.mllib.linalg import Vectors
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
>>> gbt = GBTRegressor(maxIter=5, maxDepth=2)
>>> model = gbt.fit(df)
+ >>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
+ True
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
0.0
@@ -518,7 +558,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
return self.getOrDefault(self.stepSize)
-class GBTRegressionModel(JavaModel):
+class GBTRegressionModel(TreeEnsembleModels):
"""
Model fitted by GBTRegressor.
"""