From 182e11904bf2093c2faa57894a1c4bb11d872596 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Tue, 9 Aug 2016 03:39:57 -0700 Subject: [SPARK-16933][ML] Fix AFTAggregator in AFTSurvivalRegression serializing unnecessary data. ## What changes were proposed in this pull request? Similar to ```LeastSquaresAggregator``` in #14109, ```AFTAggregator``` used for ```AFTSurvivalRegression``` ends up serializing the ```parameters``` and ```featuresStd```, which is not necessary and can cause performance issues for high-dimensional data. This patch removes this serialization. This PR is highly inspired by #14109. ## How was this patch tested? I tested this locally and verified the serialization reduction. Before patch ![image](https://cloud.githubusercontent.com/assets/1962026/17512035/abb93f04-5dda-11e6-97d3-8ae6b61a0dfd.png) After patch ![image](https://cloud.githubusercontent.com/assets/1962026/17512024/9e0dc44c-5dda-11e6-93d0-6e130ba0d6aa.png) Author: Yanbo Liang Closes #14519 from yanboliang/spark-16933. 
--- .../ml/regression/AFTSurvivalRegression.scala | 47 +++++++++++++--------- 1 file changed, 29 insertions(+), 18 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index be234f7fea..3179f4882f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} @@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S "columns. This behavior is different from R survival::survreg.") } - val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) + val bcFeaturesStd = instances.context.broadcast(featuresStd) + + val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) /* @@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S state.x.toArray.clone() } + bcFeaturesStd.destroy(blocking = false) if (handlePersistence) instances.unpersist() val rawCoefficients = parameters.slice(2, parameters.length) @@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] * $$ *

* - * @param parameters including three part: The log of scale parameter, the intercept and - * regression coefficients corresponding to the features. + * @param bcParameters The broadcasted value includes three part: The log of scale parameter, + * the intercept and regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. - * @param featuresStd The standard deviation values of the features. + * @param bcFeaturesStd The broadcast standard deviation values of the features. */ private class AFTAggregator( - parameters: BDV[Double], + bcParameters: Broadcast[BDV[Double]], fitIntercept: Boolean, - featuresStd: Array[Double]) extends Serializable { + bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable { + private val length = bcParameters.value.length + // make transient so we do not serialize between aggregation stages + @transient private lazy val parameters = bcParameters.value // the regression coefficients to the covariates - private val coefficients = parameters.slice(2, parameters.length) - private val intercept = parameters(1) + @transient private lazy val coefficients = parameters.slice(2, length) + @transient private lazy val intercept = parameters(1) // sigma is the scale parameter of the AFT model - private val sigma = math.exp(parameters(0)) + @transient private lazy val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 // Here we optimize loss function over log(sigma), intercept and coefficients - private val gradientSumArray = Array.ofDim[Double](parameters.length) + private val gradientSumArray = Array.ofDim[Double](length) def count: Long = totalCnt def loss: Double = { @@ -524,11 +531,13 @@ private class AFTAggregator( val ti = data.label val delta = data.censor + val localFeaturesStd = bcFeaturesStd.value + val margin = { var sum = 0.0 xi.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - sum += 
coefficients(index) * (value / featuresStd(index)) + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + sum += coefficients(index) * (value / localFeaturesStd(index)) } } sum + intercept @@ -542,8 +551,8 @@ private class AFTAggregator( gradientSumArray(0) += delta + multiplier * sigma * epsilon gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } xi.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index)) } } @@ -565,8 +574,7 @@ private class AFTAggregator( lossSum += other.lossSum var i = 0 - val len = this.gradientSumArray.length - while (i < len) { + while (i < length) { this.gradientSumArray(i) += other.gradientSumArray(i) i += 1 } @@ -583,12 +591,14 @@ private class AFTAggregator( private class AFTCostFun( data: RDD[AFTPoint], fitIntercept: Boolean, - featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { + bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { + val bcParameters = data.context.broadcast(parameters) + val aftAggregator = data.treeAggregate( - new AFTAggregator(parameters, fitIntercept, featuresStd))( + new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, @@ -596,6 +606,7 @@ private class AFTCostFun( case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) }) + bcParameters.destroy(blocking = false) (aftAggregator.loss, aftAggregator.gradient) } } -- cgit v1.2.3