aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-06 10:52:25 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-06 10:52:25 -0800
commit3aa3488225af12a77da3ba807906bc6a461ef11c (patch)
treea817cf07b4cbb89b681d80819a75038610d566a4 /python/pyspark/ml/regression.py
parent95eb65163391b9e910277a948b72efccf6136e0c (diff)
downloadspark-3aa3488225af12a77da3ba807906bc6a461ef11c.tar.gz
spark-3aa3488225af12a77da3ba807906bc6a461ef11c.tar.bz2
spark-3aa3488225af12a77da3ba807906bc6a461ef11c.zip
[SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier & DecisionTreeRegressor should support setSeed
PySpark ```DecisionTreeClassifier``` & ```DecisionTreeRegressor``` should support ```setSeed``` like what we do at Scala side. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9807 from yanboliang/spark-11815.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py14
1 files changed, 9 insertions, 5 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a0bb8ceed8..401bac0223 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval):
+ DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
+ HasSeed):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
@@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"):
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
+ seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+ impurity="variance", seed=None)
"""
super(DecisionTreeRegressor, self).__init__()
self._java_obj = self._new_java_obj(
@@ -435,11 +438,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="variance"):
+ impurity="variance", seed=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+ impurity="variance", seed=None)
Sets params for the DecisionTreeRegressor.
"""
kwargs = self.setParams._input_kwargs