From 3aa3488225af12a77da3ba807906bc6a461ef11c Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 6 Jan 2016 10:52:25 -0800 Subject: [SPARK-11815][ML][PYSPARK] PySpark DecisionTreeClassifier & DecisionTreeRegressor should support setSeed PySpark ```DecisionTreeClassifier``` & ```DecisionTreeRegressor``` should support ```setSeed``` like what we do at Scala side. Author: Yanbo Liang Closes #9807 from yanboliang/spark-11815. --- python/pyspark/ml/regression.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) (limited to 'python/pyspark/ml/regression.py') diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a0bb8ceed8..401bac0223 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval): + DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, + HasSeed): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", + seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None) """ super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -435,11 +438,12 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance"): + impurity="variance", seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None) Sets params for the DecisionTreeRegressor. """ kwargs = self.setParams._input_kwargs -- cgit v1.2.3