Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 128
1 file changed, 83 insertions(+), 45 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index ba5708ab8d..89ba6ab5d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -31,8 +31,9 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.storage.StorageLevel
@@ -103,7 +104,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ SchemaUtils.checkNumericType(schema, $(labelCol))
}
if (hasQuantilesCol) {
SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
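
With SchemaUtils.checkNumericType replacing the strict DoubleType check, any numeric label column now passes validation; the widening to Double happens in extractAFTPoints below via an explicit cast. A minimal sketch of what this enables (our own example, assuming a SparkSession named spark):

import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical toy data: Int labels now pass schema validation and are
// cast to Double internally before training.
val training = spark.createDataFrame(Seq(
  (1, 1.0, Vectors.dense(1.560, -0.605)),
  (2, 0.0, Vectors.dense(0.346, 2.158))
)).toDF("label", "censor", "features")

val model = new AFTSurvivalRegression().fit(training)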
@@ -183,24 +184,35 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
* Extract [[featuresCol]], [[labelCol]] and [[censorCol]] from input dataset,
* and put them in an RDD with strong types.
*/
- protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
- dataset.select($(featuresCol), $(labelCol), $(censorCol)).rdd.map {
- case Row(features: Vector, label: Double, censor: Double) =>
- AFTPoint(features, label, censor)
- }
+ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = {
+ dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol)))
+ .rdd.map {
+ case Row(features: Vector, label: Double, censor: Double) =>
+ AFTPoint(features, label, censor)
+ }
}
- @Since("1.6.0")
- override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): AFTSurvivalRegressionModel = {
validateAndTransformSchema(dataset.schema, fitting = true)
val instances = extractAFTPoints(dataset)
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
- val costFun = new AFTCostFun(instances, $(fitIntercept))
+ val featuresSummarizer = {
+ val seqOp = (c: MultivariateOnlineSummarizer, v: AFTPoint) => c.add(v.features)
+ val combOp = (c1: MultivariateOnlineSummarizer, c2: MultivariateOnlineSummarizer) => {
+ c1.merge(c2)
+ }
+ instances.treeAggregate(new MultivariateOnlineSummarizer)(seqOp, combOp)
+ }
+
+ val featuresStd = featuresSummarizer.variance.toArray.map(math.sqrt)
+
+ val costFun = new AFTCostFun(instances, $(fitIntercept), featuresStd)
val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
- val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+ val numFeatures = featuresStd.size
/*
The parameters vector has three parts:
the first element: Double, log(sigma), the log of scale parameter
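
The new summarizer pass serves two purposes: it replaces the earlier take(1)-based lookup of numFeatures, and its variances supply the per-feature standard deviations used to standardize features inside AFTAggregator below. The same treeAggregate pattern in isolation (a sketch; the helper name featureStd is ours):

import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
import org.apache.spark.rdd.RDD

def featureStd(features: RDD[Vector]): Array[Double] = {
  // seqOp folds one vector into a partition-local summarizer;
  // combOp merges the partial summaries pairwise up a tree.
  val summary = features.treeAggregate(new MultivariateOnlineSummarizer)(
    (s, v) => s.add(v),
    (s1, s2) => s1.merge(s2))
  summary.variance.toArray.map(math.sqrt)
}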
@@ -229,7 +241,13 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
if (handlePersistence) instances.unpersist()
- val coefficients = Vectors.dense(parameters.slice(2, parameters.length))
+ val rawCoefficients = parameters.slice(2, parameters.length)
+ var i = 0
+ while (i < numFeatures) {
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
+ i += 1
+ }
+ val coefficients = Vectors.dense(rawCoefficients)
val intercept = parameters(1)
val scale = math.exp(parameters(0))
val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
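
The while loop above maps coefficients learned on standardized features back to the original scale: since training saw x_j / std_j, each raw coefficient is divided by std_j, and constant features (std_j == 0.0) receive a zero coefficient. A functional restatement of the same back-transform (the name unscale is ours):

def unscale(raw: Array[Double], featuresStd: Array[Double]): Array[Double] =
  raw.zip(featuresStd).map { case (coef, std) =>
    if (std != 0.0) coef / std else 0.0
  }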
@@ -298,8 +316,8 @@ class AFTSurvivalRegressionModel private[ml] (
math.exp(BLAS.dot(coefficients, features) + intercept)
}
- @Since("1.6.0")
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
val predictUDF = udf { features: Vector => predict(features) }
val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
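
Widening transform (and fit above) from DataFrame to Dataset[_] follows the 2.0 API, where DataFrame is an alias for Dataset[Row]: existing DataFrame callers compile unchanged, and typed Datasets are accepted as well. A minimal illustration (ours, not part of the patch):

import org.apache.spark.sql.{DataFrame, Dataset}

// A method taking Dataset[_] accepts any Dataset, including DataFrame,
// because DataFrame = Dataset[Row] conforms to the wildcard type.
def transformLike(dataset: Dataset[_]): Unit =
  println(s"schema: ${dataset.schema.simpleString}")

def demo(df: DataFrame): Unit = transformLike(df)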
@@ -433,29 +451,36 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
* @param parameters including three parts: the log of the scale parameter, the intercept and
* regression coefficients corresponding to the features.
* @param fitIntercept Whether to fit an intercept term.
+ * @param featuresStd The standard deviation values of the features.
*/
-private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
- extends Serializable {
+private class AFTAggregator(
+ parameters: BDV[Double],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends Serializable {
// the regression coefficients corresponding to the covariates
private val coefficients = parameters.slice(2, parameters.length)
- private val intercept = parameters.valueAt(1)
+ private val intercept = parameters(1)
// sigma is the scale parameter of the AFT model
private val sigma = math.exp(parameters(0))
private var totalCnt: Long = 0L
private var lossSum = 0.0
- private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
- private var gradientInterceptSum = 0.0
- private var gradientLogSigmaSum = 0.0
+ // Here we optimize loss function over log(sigma), intercept and coefficients
+ private val gradientSumArray = Array.ofDim[Double](parameters.length)
def count: Long = totalCnt
+ def loss: Double = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ lossSum / totalCnt
+ }
+ def gradient: BDV[Double] = {
+ require(totalCnt > 0.0, s"The number of instances should be " +
+ s"greater than 0.0, but got $totalCnt.")
+ new BDV(gradientSumArray.map(_ / totalCnt.toDouble))
+ }
- def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
-
- // Here we optimize loss function over coefficients, intercept and log(sigma)
- def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
- BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
/**
* Add a new training instance to this AFTAggregator, and update the loss and gradient
@@ -465,25 +490,32 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* @return This AFTAggregator object.
*/
def add(data: AFTPoint): this.type = {
-
- val interceptFlag = if (fitIntercept) 1.0 else 0.0
-
- val xi = data.features.toBreeze
+ val xi = data.features
val ti = data.label
val delta = data.censor
- val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
- lossSum += math.log(sigma) * delta
- lossSum += (math.exp(epsilon) - delta * epsilon)
+ val margin = {
+ var sum = 0.0
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ sum += coefficients(index) * (value / featuresStd(index))
+ }
+ }
+ sum + intercept
+ }
+ val epsilon = (math.log(ti) - margin) / sigma
+
+ lossSum += delta * math.log(sigma) - delta * epsilon + math.exp(epsilon)
- // Sanity check (should never occur):
- assert(!lossSum.isInfinity,
- s"AFTAggregator loss sum is infinity. Error for unknown reason.")
+ val multiplier = (delta - math.exp(epsilon)) / sigma
- val deltaMinusExpEps = delta - math.exp(epsilon)
- gradientCoefficientSum += xi * deltaMinusExpEps / sigma
- gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
- gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
+ gradientSumArray(0) += delta + multiplier * sigma * epsilon
+ gradientSumArray(1) += { if (fitIntercept) multiplier else 0.0 }
+ xi.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ gradientSumArray(index + 2) += multiplier * (value / featuresStd(index))
+ }
+ }
totalCnt += 1
this
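
For reference (our notation, not part of the patch): with standardized features \hat{x}_{ij} = x_{ij}/\mathrm{std}_j, the per-instance loss accumulated above and the partial derivatives written to slots 0, 1, and j + 2 of gradientSumArray are

\[
\ell_i = \delta_i \log\sigma - \delta_i \varepsilon_i + e^{\varepsilon_i},
\qquad
\varepsilon_i = \frac{\log t_i - \hat{x}_i^{\top}\beta - b}{\sigma},
\]
\[
\frac{\partial \ell_i}{\partial \log\sigma} = \delta_i + \varepsilon_i\bigl(\delta_i - e^{\varepsilon_i}\bigr),
\qquad
\frac{\partial \ell_i}{\partial b} = \frac{\delta_i - e^{\varepsilon_i}}{\sigma},
\qquad
\frac{\partial \ell_i}{\partial \beta_j} = \frac{\delta_i - e^{\varepsilon_i}}{\sigma}\,\hat{x}_{ij},
\]

so multiplier is exactly (\delta_i - e^{\varepsilon_i})/\sigma, and the foreachActive loops visit only the nonzero entries of sparse feature vectors while skipping constant features whose standard deviation is zero.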
@@ -502,9 +534,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
totalCnt += other.totalCnt
lossSum += other.lossSum
- gradientCoefficientSum += other.gradientCoefficientSum
- gradientInterceptSum += other.gradientInterceptSum
- gradientLogSigmaSum += other.gradientLogSigmaSum
+ var i = 0
+ val len = this.gradientSumArray.length
+ while (i < len) {
+ this.gradientSumArray(i) += other.gradientSumArray(i)
+ i += 1
+ }
}
this
}
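
Collapsing the three accumulators into one flat array lets merge run as a single allocation-free while loop. For comparison, the allocating alternative it avoids, shown on plain arrays (a sketch; the name mergedGradients is ours):

// Zip-based merging allocates tuples plus an intermediate array per call,
// which adds up when treeAggregate merges many partial aggregators.
def mergedGradients(a: Array[Double], b: Array[Double]): Array[Double] =
  a.zip(b).map { case (x, y) => x + y }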
@@ -515,12 +550,15 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
* It returns the loss and gradient at a particular point (parameters).
* It's used in Breeze's convex optimization routines.
*/
-private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
- extends DiffFunction[BDV[Double]] {
+private class AFTCostFun(
+ data: RDD[AFTPoint],
+ fitIntercept: Boolean,
+ featuresStd: Array[Double]) extends DiffFunction[BDV[Double]] {
override def calculate(parameters: BDV[Double]): (Double, BDV[Double]) = {
- val aftAggregator = data.treeAggregate(new AFTAggregator(parameters, fitIntercept))(
+ val aftAggregator = data.treeAggregate(
+ new AFTAggregator(parameters, fitIntercept, featuresStd))(
seqOp = (c, v) => (c, v) match {
case (aggregator, instance) => aggregator.add(instance)
},
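
AFTCostFun implements Breeze's DiffFunction contract: calculate returns the loss and its gradient at the current parameter vector, which is all the BreezeLBFGS optimizer constructed in fit() needs. A self-contained sketch of that contract on a toy quadratic (ours, not from the patch):

import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{DiffFunction, LBFGS}

// f(x) = ||x - 3||^2 with gradient 2 * (x - 3); minimum at (3, ..., 3).
val quadratic = new DiffFunction[BDV[Double]] {
  override def calculate(x: BDV[Double]): (Double, BDV[Double]) = {
    val shifted = x - 3.0
    (shifted.dot(shifted), shifted * 2.0)
  }
}

val lbfgs = new LBFGS[BDV[Double]](maxIter = 100, m = 10, tolerance = 1e-6)
val solution = lbfgs.minimize(quadratic, BDV.zeros[Double](2))  // ~ (3.0, 3.0)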