aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--python/pyspark/ml/classification.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index fdeccf822c..850d775db0 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -520,7 +520,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
- >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed")
+ >>> gbt = GBTClassifier(maxIter=5, maxDepth=2, labelCol="indexed", seed=42)
>>> model = gbt.fit(td)
>>> allclose(model.treeWeights, [1.0, 0.1, 0.1, 0.1, 0.1])
True
@@ -543,19 +543,19 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic",
- maxIter=20, stepSize=0.1):
+ maxIter=20, stepSize=0.1, seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- lossType="logistic", maxIter=20, stepSize=0.1)
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
"""
super(GBTClassifier, self).__init__()
self._java_obj = self._new_java_obj(
"org.apache.spark.ml.classification.GBTClassifier", self.uid)
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- lossType="logistic", maxIter=20, stepSize=0.1)
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -564,12 +564,12 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
- lossType="logistic", maxIter=20, stepSize=0.1):
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- lossType="logistic", maxIter=20, stepSize=0.1)
+ lossType="logistic", maxIter=20, stepSize=0.1, seed=None)
Sets params for Gradient Boosted Tree Classification.
"""
kwargs = self.setParams._input_kwargs