author    Narine Kokhlikyan <narine.kokhlikyan@gmail.com>    2016-02-22 17:26:32 -0800
committer Xiangrui Meng <meng@databricks.com>    2016-02-22 17:26:32 -0800
commit    33ef3aa7eabbe323620eb77fa94a53996ed0251d (patch)
tree      610e341f05d575eb01bb18fc896189ff4ebf042c /mllib
parent    02b1fefffb00d50c1076a26f2f3f41f3c1fa0001 (diff)
[SPARK-13295][ML][MLLIB] AFTSurvivalRegression.AFTAggregator improvements - avoid creating new instances of arrays/vectors for each record
As noted by the TODO in the AFTAggregator.add(data: AFTPoint) method, a new array was created for the intercept value on every record and concatenated with the array of betas; the resulting array was converted into a dense vector, which in turn was converted into a Breeze vector. This is expensive and unnecessary. This patch solves the problem with a simple algebraic decomposition: the intercept is stored and treated independently of the coefficients, so no per-record allocation is needed.

Please let me know what you think and if you have any questions. Thanks, Narine

Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>

Closes #11179 from NarineK/survivaloptim.
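For illustration, here is a minimal standalone sketch (not part of the patch; names and values are hypothetical) of the identity the patch relies on: dotting the augmented vector [1.0 | x] with [intercept | coefficients] equals coefficients.dot(x) + intercept, so the per-record concatenation can be dropped.

import breeze.linalg.{DenseVector => BDV}

object DecompositionSketch {
  def main(args: Array[String]): Unit = {
    val intercept = 0.5                      // hypothetical values for illustration
    val coefficients = BDV(1.0, -2.0, 3.0)
    val features = BDV(0.1, 0.2, 0.3)

    // Old approach: allocate an augmented [1.0 | features] vector for the record.
    val oldDot = BDV.vertcat(BDV(intercept), coefficients) dot BDV.vertcat(BDV(1.0), features)

    // New approach: keep the intercept separate; no per-record allocation.
    val newDot = (coefficients dot features) + intercept

    assert(math.abs(oldDot - newDot) < 1e-12) // equal up to floating-point error
  }
}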
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 32
1 file changed, 17 insertions(+), 15 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e8a1ff2278..1e5b4cb83c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
extends Serializable {
- // beta is the intercept and regression coefficients to the covariates
- private val beta = parameters.slice(1, parameters.length)
+ // the regression coefficients to the covariates
+ private val coefficients = parameters.slice(2, parameters.length)
+ private val intercept = parameters.valueAt(1)
// sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
- private var gradientBetaSum = BDV.zeros[Double](beta.length)
+ private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
+ private var gradientInterceptSum = 0.0
private var gradientLogSigmaSum = 0.0
def count: Long = totalCnt
def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
- // Here we optimize loss function over beta and log(sigma)
+ // Here we optimize loss function over coefficients, intercept and log(sigma)
def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
- gradientBetaSum/totalCnt.toDouble)
+ BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
/**
* Add a new training data point to this AFTAggregator, and update the loss and gradient
@@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
*/
def add(data: AFTPoint): this.type = {
- // TODO: Don't create a new xi vector each time.
- val xi = if (fitIntercept) {
- Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze
- } else {
- Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze
- }
+ val interceptFlag = if (fitIntercept) 1.0 else 0.0
+
+ val xi = data.features.toBreeze
val ti = data.label
val delta = data.censor
- val epsilon = (math.log(ti) - beta.dot(xi)) / sigma
+ val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag) / sigma
lossSum += math.log(sigma) * delta
lossSum += (math.exp(epsilon) - delta * epsilon)
@@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
assert(!lossSum.isInfinity,
s"AFTAggregator loss sum is infinity. Error for unknown reason.")
- gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma
- gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon
+ val deltaMinusExpEps = delta - math.exp(epsilon)
+ gradientCoefficientSum += xi * deltaMinusExpEps / sigma
+ gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
+ gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
totalCnt += 1
this
@@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
totalCnt += other.totalCnt
lossSum += other.lossSum
- gradientBetaSum += other.gradientBetaSum
+ gradientCoefficientSum += other.gradientCoefficientSum
+ gradientInterceptSum += other.gradientInterceptSum
gradientLogSigmaSum += other.gradientLogSigmaSum
}
this
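For reference, a sketch of the per-record loss and gradient terms that the updated add method accumulates, reconstructed from the code above (using b for the intercept, beta for the coefficients, and delta for the censor indicator, with the implementation's sign conventions):

\varepsilon_i = \frac{\log t_i - x_i^{\top}\beta - b}{\sigma}, \qquad
\ell_i = \delta_i \log\sigma + e^{\varepsilon_i} - \delta_i\,\varepsilon_i

\frac{\partial \ell_i}{\partial \beta} = \frac{(\delta_i - e^{\varepsilon_i})\,x_i}{\sigma}, \qquad
\frac{\partial \ell_i}{\partial b} = \frac{\delta_i - e^{\varepsilon_i}}{\sigma}, \qquad
\frac{\partial \ell_i}{\partial \log\sigma} = \delta_i + (\delta_i - e^{\varepsilon_i})\,\varepsilon_i

The loss and gradient methods divide the running sums by totalCnt, so the objective being optimized is the mean of the per-record losses over the records seen.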