diff options
author | Narine Kokhlikyan <narine.kokhlikyan@gmail.com> | 2016-02-22 17:26:32 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-02-22 17:26:32 -0800 |
commit | 33ef3aa7eabbe323620eb77fa94a53996ed0251d (patch) | |
tree | 610e341f05d575eb01bb18fc896189ff4ebf042c /mllib/src | |
parent | 02b1fefffb00d50c1076a26f2f3f41f3c1fa0001 (diff) | |
download | spark-33ef3aa7eabbe323620eb77fa94a53996ed0251d.tar.gz spark-33ef3aa7eabbe323620eb77fa94a53996ed0251d.tar.bz2 spark-33ef3aa7eabbe323620eb77fa94a53996ed0251d.zip |
[SPARK-13295][ML][MLLIB] AFTSurvivalRegression.AFTAggregator improvements - avoid creating new instances of arrays/vectors for each record
As also marked by a TODO in the AFTAggregator.AFTAggregator.add(data: AFTPoint) method, a new array is being created for the intercept value, and it is concatenated
with another array which contains the betas; the resulting Array is converted into a Dense vector, which in its turn is converted into a breeze vector.
This is expensive and not necessarily beautiful.
I've tried to solve the above-mentioned problem by simple algebraic decompositions - keeping and treating the intercept independently.
Please let me know what you think and if you have any questions.
Thanks,
Narine
Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>
Closes #11179 from NarineK/survivaloptim.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 32 |
1 file changed, 17 insertions, 15 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index e8a1ff2278..1e5b4cb83c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel] private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) extends Serializable { - // beta is the intercept and regression coefficients to the covariates - private val beta = parameters.slice(1, parameters.length) + // the regression coefficients to the covariates + private val coefficients = parameters.slice(2, parameters.length) + private val intercept = parameters.valueAt(1) // sigma is the scale parameter of the AFT model private val sigma = math.exp(parameters(0)) private var totalCnt: Long = 0L private var lossSum = 0.0 - private var gradientBetaSum = BDV.zeros[Double](beta.length) + private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length) + private var gradientInterceptSum = 0.0 private var gradientLogSigmaSum = 0.0 def count: Long = totalCnt def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt - // Here we optimize loss function over beta and log(sigma) + // Here we optimize loss function over coefficients, intercept and log(sigma) def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)), - gradientBetaSum/totalCnt.toDouble) + BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble) /** * Add a new training data to this AFTAggregator, and update the loss and gradient @@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) */ def add(data: AFTPoint): this.type = { - // TODO: Don't create a new xi vector each time. 
- val xi = if (fitIntercept) { - Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze - } else { - Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze - } + val interceptFlag = if (fitIntercept) 1.0 else 0.0 + + val xi = data.features.toBreeze val ti = data.label val delta = data.censor - val epsilon = (math.log(ti) - beta.dot(xi)) / sigma + val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma lossSum += math.log(sigma) * delta lossSum += (math.exp(epsilon) - delta * epsilon) @@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) assert(!lossSum.isInfinity, s"AFTAggregator loss sum is infinity. Error for unknown reason.") - gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma - gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon + val deltaMinusExpEps = delta - math.exp(epsilon) + gradientCoefficientSum += xi * deltaMinusExpEps / sigma + gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma + gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon totalCnt += 1 this @@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean) totalCnt += other.totalCnt lossSum += other.lossSum - gradientBetaSum += other.gradientBetaSum + gradientCoefficientSum += other.gradientCoefficientSum + gradientInterceptSum += other.gradientInterceptSum gradientLogSigmaSum += other.gradientLogSigmaSum } this |