diff options
author | WeichenXu <WeichenXu123@outlook.com> | 2016-09-22 04:35:54 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-09-22 04:35:54 -0700 |
commit | 72d9fba26c19aae73116fd0d00b566967934c6fc (patch) | |
tree | ec8a590dd79baa81f7949c471634ea7443b7f526 /python/pyspark/ml | |
parent | 646f383465c123062cbcce288a127e23984c7c7f (diff) | |
download | spark-72d9fba26c19aae73116fd0d00b566967934c6fc.tar.gz spark-72d9fba26c19aae73116fd0d00b566967934c6fc.tar.bz2 spark-72d9fba26c19aae73116fd0d00b566967934c6fc.zip |
[SPARK-17281][ML][MLLIB] Add treeAggregateDepth parameter for AFTSurvivalRegression
## What changes were proposed in this pull request?
Add treeAggregateDepth parameter for AFTSurvivalRegression to keep consistent with LiR/LoR.
## How was this patch tested?
Existing tests.
Author: WeichenXu <WeichenXu123@outlook.com>
Closes #14851 from WeichenXu123/add_treeAggregate_param_for_survival_regression.
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/regression.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 19afc723bb..55d38033ef 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1088,7 +1088,8 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable, @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): + HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth, + JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None) + quantilesCol=None, aggregationDepth=2) """ super(AFTSurvivalRegression, self).__init__() self._java_obj = self._new_java_obj( @@ -1174,12 +1175,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]), - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \ quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ - quantilesCol=None): + quantilesCol=None, aggregationDepth=2): """ kwargs = self.setParams._input_kwargs return self._set(**kwargs) |