author     WeichenXu <WeichenXu123@outlook.com>    2016-09-22 04:35:54 -0700
committer  Yanbo Liang <ybliang8@gmail.com>        2016-09-22 04:35:54 -0700
commit     72d9fba26c19aae73116fd0d00b566967934c6fc (patch)
tree       ec8a590dd79baa81f7949c471634ea7443b7f526
parent     646f383465c123062cbcce288a127e23984c7c7f (diff)
[SPARK-17281][ML][MLLIB] Add treeAggregateDepth parameter for AFTSurvivalRegression
## What changes were proposed in this pull request?

Add treeAggregateDepth parameter for AFTSurvivalRegression to keep consistent with LiR/LoR.

## How was this patch tested?

Existing tests.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #14851 from WeichenXu123/add_treeAggregate_param_for_survival_regression.
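For context, here is a minimal PySpark usage sketch of the parameter this patch exposes. The SparkSession named spark, the toy data, and the chosen depth are illustrative assumptions, not part of the commit:

from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import AFTSurvivalRegression

# Toy survival data: (label, censor, features); the values are arbitrary.
df = spark.createDataFrame([
    (1.218, 1.0, Vectors.dense(1.560, -0.605)),
    (2.949, 0.0, Vectors.dense(0.346, 2.158)),
    (3.627, 0.0, Vectors.dense(1.380, 0.231)),
], ["label", "censor", "features"])

# aggregationDepth mirrors the same expert param on LinearRegression and
# LogisticRegression; raise it when features or partitions are numerous.
aft = AFTSurvivalRegression(censorCol="censor", aggregationDepth=2)
model = aft.fit(df)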
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 24
-rw-r--r--  python/pyspark/ml/regression.py | 11
2 files changed, 25 insertions(+), 10 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3179f4882f..9d5ba99978 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -46,7 +46,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait AFTSurvivalRegressionParams extends Params
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
- with HasTol with HasFitIntercept with Logging {
+ with HasTol with HasFitIntercept with HasAggregationDepth with Logging {
/**
* Param for censor column name.
@@ -184,6 +184,17 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
setDefault(tol -> 1E-6)
/**
+ * Suggested depth for treeAggregate (>= 2).
+ * If the dimensions of features or the number of partitions are large,
+ * this param could be adjusted to a larger size.
+ * Default is 2.
+ * @group expertSetParam
+ */
+ @Since("2.1.0")
+ def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+ setDefault(aggregationDepth -> 2)
+
+ /**
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put it in an RDD with strong types.
*/
@@ -207,7 +218,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
c1.merge(c2)
}
- instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+ instances.treeAggregate(
+ new MultivariateOnlineSummarizer
+ )(seqOp, combOp, $(aggregationDepth))
}
val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
@@ -222,7 +235,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
val bcFeaturesStd = instances.context.broadcast(featuresStd)
- val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
+ val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd, $(aggregationDepth))
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
/*
@@ -591,7 +604,8 @@ private class AFTAggregator(
private class AFTCostFun(
data: RDD[AFTPoint],
fitIntercept: Boolean,
- bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
+ bcFeaturesStd: Broadcast[Array[Double]],
+ aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
@@ -604,7 +618,7 @@ private class AFTCostFun(
},
combOp = (c1, c2) => (c1, c2) match {
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
- })
+ }, depth = aggregationDepth)
bcParameters.destroy(blocking = false)
(aftAggregator.loss, aftAggregator.gradient)
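The Scala change above simply threads $(aggregationDepth) into treeAggregate, which merges per-partition partial aggregates through a multi-level tree rather than shipping every partition's result straight to the driver. The same knob is visible on the public RDD API; a small sketch follows (arbitrary data, and it assumes an existing SparkContext named sc):

# treeAggregate(zeroValue, seqOp, combOp, depth): depth sets the number of
# levels in the aggregation tree; extra levels add intermediate merge stages,
# which helps when partitions are numerous or the aggregated state is wide.
rdd = sc.parallelize(range(1000), 64)

seq_op = lambda acc, x: acc + x   # fold one record into a partial sum
comb_op = lambda a, b: a + b      # merge two partial sums

assert rdd.treeAggregate(0, seq_op, comb_op, depth=2) == 499500  # default depth
assert rdd.treeAggregate(0, seq_op, comb_op, depth=4) == 499500  # deeper tree
# The result is identical at any depth; only the aggregation topology changes.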
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 19afc723bb..55d38033ef 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1088,7 +1088,8 @@ class GBTRegressionModel(TreeEnsembleModel, JavaPredictionModel, JavaMLWritable,
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
+ HasFitIntercept, HasMaxIter, HasTol, HasAggregationDepth,
+ JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -1153,12 +1154,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
- quantilesCol=None):
+ quantilesCol=None, aggregationDepth=2):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
- quantilesCol=None)
+ quantilesCol=None, aggregationDepth=2)
"""
super(AFTSurvivalRegression, self).__init__()
self._java_obj = self._new_java_obj(
@@ -1174,12 +1175,12 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
- quantilesCol=None):
+ quantilesCol=None, aggregationDepth=2):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \
- quantilesCol=None):
+ quantilesCol=None, aggregationDepth=2):
"""
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
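On the Python side the new keyword also flows through setParams via the _input_kwargs machinery shown above. A short sketch of that path (the depth value is hypothetical):

from pyspark.ml.regression import AFTSurvivalRegression

aft = AFTSurvivalRegression()
aft.setParams(censorCol="censor", aggregationDepth=3)
# getAggregationDepth() comes with the HasAggregationDepth mixin added above.
assert aft.getAggregationDepth() == 3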