diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-08-09 03:39:57 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-08-09 03:39:57 -0700 |
commit | 182e11904bf2093c2faa57894a1c4bb11d872596 (patch) | |
tree | ed32964fc35e5626ccc698de03a67d30d2e3c0d0 /mllib | |
parent | 511f52f8423e151b0d0133baf040d34a0af3d422 (diff) | |
download | spark-182e11904bf2093c2faa57894a1c4bb11d872596.tar.gz spark-182e11904bf2093c2faa57894a1c4bb11d872596.tar.bz2 spark-182e11904bf2093c2faa57894a1c4bb11d872596.zip |
[SPARK-16933][ML] Fix AFTAggregator in AFTSurvivalRegression serializes unnecessary data.
## What changes were proposed in this pull request?
Similar to ```LeastSquaresAggregator``` in #14109, ```AFTAggregator``` used for ```AFTSurvivalRegression``` ends up serializing the ```parameters``` and ```featuresStd```, which is not necessary and can cause performance issues for high dimensional data. This patch removes this serialization. This PR is highly inspired by #14109.
## How was this patch tested?
I tested this locally and verified the serialization reduction.
Before patch
![image](https://cloud.githubusercontent.com/assets/1962026/17512035/abb93f04-5dda-11e6-97d3-8ae6b61a0dfd.png)
After patch
![image](https://cloud.githubusercontent.com/assets/1962026/17512024/9e0dc44c-5dda-11e6-93d0-6e130ba0d6aa.png)
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #14519 from yanboliang/spark-16933.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 47 |
1 files changed, 29 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index be234f7fea..3179f4882f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT} @@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S "columns. This behavior is different from R survival::survreg.") } - val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd) + val bcFeaturesStd = instances.context.broadcast(featuresStd) + + val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd) val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) /* @@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S state.x.toArray.clone() } + bcFeaturesStd.destroy(blocking = false) if (handlePersistence) instances.unpersist() val rawCoefficients = parameters.slice(2, parameters.length) @@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] * $$ * </blockquote></p> * - * @param parameters including three part: The log of scale parameter, the intercept and - * regression coefficients corresponding to the features. + * @param bcParameters The broadcasted value includes three part: The log of scale parameter, + * the intercept and regression coefficients corresponding to the features. * @param fitIntercept Whether to fit an intercept term. - * @param featuresStd The standard deviation values of the features. + * @param bcFeaturesStd The broadcast standard deviation values of the features. */ private class AFTAggregator( - parameters: BDV[Double], + bcParameters: Broadcast[BDV[Double]], fitIntercept: Boolean, - featuresStd: Array[Double]) extends Serializable { + bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable { + private val length = bcParameters.value.length + // make transient so we do not serialize between aggregation stages + @transient private lazy val parameters = bcParameters.value // the regression coefficients to the covariates - private val coefficients = parameters.slice(2, parameters.length) - private val intercept = parameters(1) + @transient private lazy val coefficients = parameters.slice(2, length) + @transient private lazy val intercept = parameters(1) // sigma is the scale parameter of the AFT model - private val sigma = math.exp(parameters(0)) + @transient private lazy val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 // Here we optimize loss function over log(sigma), intercept and coefficients - private val gradientSumArray = Array.ofDim[Double](parameters.length) + private val gradientSumArray = Array.ofDim[Double](length) def count: Long = totalCnt def loss: Double = { @@ -524,11 +531,13 @@ private class AFTAggregator( val ti = data.label val delta = data.censor + val localFeaturesStd = bcFeaturesStd.value + val margin = { var sum = 0.0 xi.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - sum += coefficients(index) * (value / featuresStd(index)) + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + sum += coefficients(index) * (value / localFeaturesStd(index)) } } sum + intercept @@ -542,8 +551,8 @@ private class AFTAggregator( gradientSumArray(0) += delta + multiplier * sigma * epsilon gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 } xi.foreachActive { (index, value) => - if (featuresStd(index) != 0.0 && value != 0.0) { - gradientSumArray(index + 2) += multiplier * (value / featuresStd(index)) + if (localFeaturesStd(index) != 0.0 && value != 0.0) { + gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index)) } } @@ -565,8 +574,7 @@ private class AFTAggregator( lossSum += other.lossSum var i = 0 - val len = this.gradientSumArray.length - while (i < len) { + while (i < length) { this.gradientSumArray(i) += other.gradientSumArray(i) i += 1 } @@ -583,12 +591,14 @@ private class AFTAggregator( private class AFTCostFun( data: RDD[AFTPoint], fitIntercept: Boolean, - featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] { + bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] { override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = { + val bcParameters = data.context.broadcast(parameters) + val aftAggregator = data.treeAggregate( - new AFTAggregator(parameters, fitIntercept, featuresStd))( + new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))( seqOp = (c, v) => (c, v) match { case (aggregator, instance) => aggregator.add(instance) }, @@ -596,6 +606,7 @@ private class AFTCostFun( case (aggregator1, aggregator2) => aggregator1.merge(aggregator2) }) + bcParameters.destroy(blocking = false) (aftAggregator.loss, aftAggregator.gradient) } } |