aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-06 10:52:25 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-06 10:52:25 -0800
commit3aa3488225af12a77da3ba807906bc6a461ef11c (patch)
treea817cf07b4cbb89b681d80819a75038610d566a4 /python
parent95eb65163391b9e910277a948b72efccf6136e0c (diff)
downloadspark-3aa3488225af12a77da3ba807906bc6a461ef11c.tar.gz
spark-3aa3488225af12a77da3ba807906bc6a461ef11c.tar.bz2
spark-3aa3488225af12a77da3ba807906bc6a461ef11c.zip
[SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier & DecisionTreeRegressor should support setSeed
PySpark ```DecisionTreeClassifier``` & ```DecisionTreeRegressor``` should support ```setSeed``` like what we do at Scala side. Author: Yanbo Liang <ybliang8@gmail.com> Closes #9807 from yanboliang/spark-11815.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py13
-rw-r--r--python/pyspark/ml/regression.py14
2 files changed, 17 insertions, 10 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 5599b8f3ec..265c6a14f1 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -273,7 +273,7 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams,
- TreeClassifierParams, HasCheckpointInterval):
+ TreeClassifierParams, HasCheckpointInterval, HasSeed):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for classification.
@@ -313,12 +313,14 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
+ seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
+ seed=None)
"""
super(DecisionTreeClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -335,12 +337,13 @@ class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
probabilityCol="probability", rawPredictionCol="rawPrediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="gini"):
+ impurity="gini", seed=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
probabilityCol="probability", rawPredictionCol="rawPrediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
+ seed=None)
Sets params for the DecisionTreeClassifier.
"""
kwargs = self.setParams._input_kwargs
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index a0bb8ceed8..401bac0223 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams):
@inherit_doc
class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval):
+ DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval,
+ HasSeed):
"""
`http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree`
learning algorithm for regression.
@@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"):
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
+ seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+ impurity="variance", seed=None)
"""
super(DecisionTreeRegressor, self).__init__()
self._java_obj = self._new_java_obj(
@@ -435,11 +438,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- impurity="variance"):
+ impurity="variance", seed=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance")
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
+ impurity="variance", seed=None)
Sets params for the DecisionTreeRegressor.
"""
kwargs = self.setParams._input_kwargs