author    Yanbo Liang <ybliang8@gmail.com>  2016-08-09 03:39:57 -0700
committer Yanbo Liang <ybliang8@gmail.com>  2016-08-09 03:39:57 -0700
commit    182e11904bf2093c2faa57894a1c4bb11d872596 (patch)
tree      ed32964fc35e5626ccc698de03a67d30d2e3c0d0 /mllib
parent    511f52f8423e151b0d0133baf040d34a0af3d422 (diff)
[SPARK-16933][ML] Fix AFTAggregator in AFTSurvivalRegression serializing unnecessary data.
## What changes were proposed in this pull request?

Similar to ```LeastSquaresAggregator``` in #14109, the ```AFTAggregator``` used by ```AFTSurvivalRegression``` ends up serializing ```parameters``` and ```featuresStd```, which is unnecessary and can cause performance issues for high-dimensional data. This patch removes that serialization by broadcasting both values instead. This PR is highly inspired by #14109.

## How was this patch tested?

Tested locally and verified the serialization reduction.

Before patch:
![image](https://cloud.githubusercontent.com/assets/1962026/17512035/abb93f04-5dda-11e6-97d3-8ae6b61a0dfd.png)

After patch:
![image](https://cloud.githubusercontent.com/assets/1962026/17512024/9e0dc44c-5dda-11e6-93d0-6e130ba0d6aa.png)

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14519 from yanboliang/spark-16933.
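To illustrate why this matters, here is a minimal, self-contained sketch of the broadcast pattern the patch applies (the object name, array sizes, and local mode are illustrative assumptions, not part of the commit). A large read-only array is broadcast once per job, so each executor fetches and caches one copy instead of the array being re-serialized into every task closure on every optimizer iteration:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[*]").setAppName("broadcast-sketch"))

    // Stand-in for featuresStd: large, read-only, needed by every task.
    val featuresStd = Array.fill(100000)(1.0)
    // Capture the length in a local val; referencing featuresStd.length
    // inside the closure would pull the whole array back into it.
    val n = featuresStd.length

    // Broadcast once: executors fetch and cache a single copy.
    val bcFeaturesStd = sc.broadcast(featuresStd)

    val data = sc.parallelize(0 until 1000000)
    val sum = data.treeAggregate(0.0)(
      seqOp = (acc, i) => acc + bcFeaturesStd.value(i % n),
      combOp = _ + _)

    // Release executor copies once no further iterations need the value.
    bcFeaturesStd.destroy()
    println(sum)
    sc.stop()
  }
}
```

This mirrors the change to `fit` below: broadcast `featuresStd` before constructing `AFTCostFun`, destroy it after L-BFGS finishes.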
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 47
1 file changed, 29 insertions(+), 18 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index be234f7fea..3179f4882f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.linalg.{BLAS, Vector, Vectors, VectorUDT}
@@ -219,7 +220,9 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
"columns. This behavior is different from R survival::survreg.")
}
- val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
+ val bcFeaturesStd = instances.context.broadcast(featuresStd)
+
+ val costFun = new AFTCostFun(instances, $(fitIntercept), bcFeaturesStd)
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
/*
@@ -247,6 +250,7 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
state.x.toArray.clone()
}
+ bcFeaturesStd.destroy(blocking = false)
if (handlePersistence) instances.unpersist()
val rawCoefficients = parameters.slice(2, parameters.length)
@@ -478,26 +482,29 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
* $$
* </blockquote></p>
*
- * @param parameters including three part: The log of scale parameter, the intercept and
- * regression coefficients corresponding to the features.
+ * @param bcParameters The broadcast value includes three parts: the log of the scale parameter,
+ * the intercept, and the regression coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
- * @param featuresStd The standard deviation values of the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
*/
private class AFTAggregator(
- parameters: BDV[Double],
+ bcParameters: Broadcast[BDV[Double]],
fitIntercept: Boolean,
- featuresStd: Array[Double]) extends Serializable {
+ bcFeaturesStd: Broadcast[Array[Double]]) extends Serializable {
+ private val length = bcParameters.value.length
+ // make transient so we do not serialize between aggregation stages
+ @transient private lazy val parameters = bcParameters.value
// the regression coefficients to the covariates
- private val coefficients = parameters.slice(2, parameters.length)
- private val intercept = parameters(1)
+ @transient private lazy val coefficients = parameters.slice(2, length)
+ @transient private lazy val intercept = parameters(1)
// sigma is the scale parameter of the AFT model
- private val sigma = math.exp(parameters(0))
+ @transient private lazy val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
// Here we optimize loss function over log(sigma), intercept and coefficients
- private val gradientSumArray = Array.ofDim[Double](parameters.length)
+ private val gradientSumArray = Array.ofDim[Double](length)
def count: Long = totalCnt
def loss: Double = {
@@ -524,11 +531,13 @@ private class AFTAggregator(
val ti = data.label
val delta = data.censor
+ val localFeaturesStd = bcFeaturesStd.value
+
val margin = {
var sum = 0.0
xi.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- sum += coefficients(index) * (value / featuresStd(index))
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ sum += coefficients(index) * (value / localFeaturesStd(index))
}
}
sum + intercept
@@ -542,8 +551,8 @@ private class AFTAggregator(
gradientSumArray(0) += delta + multiplier * sigma * epsilon
gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
xi.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+ if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+ gradientSumArray(index + 2) += multiplier * (value / localFeaturesStd(index))
}
}
@@ -565,8 +574,7 @@ private class AFTAggregator(
lossSum += other.lossSum
var i = 0
- val len = this.gradientSumArray.length
- while (i < len) {
+ while (i < length) {
this.gradientSumArray(i) += other.gradientSumArray(i)
i += 1
}
@@ -583,12 +591,14 @@ private class AFTAggregator(
private class AFTCostFun(
data: RDD[AFTPoint],
fitIntercept: Boolean,
- featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
+ bcFeaturesStd: Broadcast[Array[Double]]) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
+ val bcParameters = data.context.broadcast(parameters)
+
val aftAggregator = data.treeAggregate(
- new AFTAggregator(parameters, fitIntercept, featuresStd))(
+ new AFTAggregator(bcParameters, fitIntercept, bcFeaturesStd))(
seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance)
},
@@ -596,6 +606,7 @@ private class AFTCostFun(
case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
})
+ bcParameters.destroy(blocking = false)
(aftAggregator.loss, aftAggregator.gradient)
}
}
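
For reference, a standalone sketch of the `@transient lazy val` idiom the patch introduces in `AFTAggregator` (the `SumAggregator` class and its fields are hypothetical, written to show the technique, not code from this commit). The aggregator ships between aggregation stages with only its small running sums; the broadcast value is re-materialized lazily on each executor:

```scala
import org.apache.spark.broadcast.Broadcast

class SumAggregator(bcWeights: Broadcast[Array[Double]]) extends Serializable {
  // @transient: excluded when the aggregator is serialized between
  // aggregation stages; lazy: rebuilt on first access after
  // deserialization, reading the executor-local broadcast copy.
  @transient private lazy val weights = bcWeights.value

  private var weightedSum = 0.0
  private var count = 0L

  def add(x: Array[Double]): this.type = {
    var i = 0
    while (i < x.length) { weightedSum += weights(i) * x(i); i += 1 }
    count += 1
    this
  }

  def merge(other: SumAggregator): this.type = {
    // merge only touches the small mutable state, never the weights,
    // so combining partial aggregates serializes almost nothing extra.
    weightedSum += other.weightedSum
    count += other.count
    this
  }

  def mean: Double = if (count == 0L) 0.0 else weightedSum / count
}

// Usage, mirroring AFTCostFun.calculate:
//   val agg = rdd.treeAggregate(new SumAggregator(bcWeights))(
//     seqOp = (a, v) => a.add(v),
//     combOp = (a, b) => a.merge(b))
```

The zero value passed to `treeAggregate` is created on the driver and serialized to every task, which is exactly why the transient fields matter: only the broadcast handle and the (initially zero) running sums go over the wire.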