author	Yanbo Liang <ybliang8@gmail.com>	2015-09-17 21:37:10 -0700
committer	Xiangrui Meng <meng@databricks.com>	2015-09-17 21:37:10 -0700
commit	98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b (patch)
tree	667348315237c6a2543080958789341d75b679f7 /mllib
parent	0f5ef6dfa67a068606aff8ea9d1addfce73446eb (diff)
[SPARK-8518] [ML] Log-linear models for survival analysis
The [Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is one of the most commonly used and most easily parallelized methods of survival analysis for censored survival data. It is a log-linear model based on the Weibull distribution of the survival time. Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the predictions. There are different kinds of model prediction; this patch implements only the type ```response```, which is the default in R. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8611 from yanboliang/spark-8518.
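A minimal usage sketch of the new estimator (editor's illustration, not part of the patch; the data is made up and a Spark 1.6-style `sqlContext` is assumed to be in scope):

```scala
import org.apache.spark.ml.regression.AFTSurvivalRegression
import org.apache.spark.mllib.linalg.Vectors

// Hypothetical training data: (label, censor, features).
// censor == 1.0 means the event occurred (uncensored); 0.0 means censored.
val training = sqlContext.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

val aft = new AFTSurvivalRegression()
  .setLabelCol("label")
  .setCensorCol("censor")
  .setFeaturesCol("features")
val model = aft.fit(training)

// "response"-type predictions, comparable to R's predict(sr.fit, newdata=...)
model.transform(training).select("features", "prediction").show()
```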
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 449
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 311
2 files changed, 760 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
new file mode 100644
index 0000000000..5b25db651f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS}
+
+import org.apache.spark.{SparkException, Logging}
+import org.apache.spark.annotation.{Since, Experimental}
+import org.apache.spark.ml.{Model, Estimator}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.{SchemaUtils, Identifiable}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, DataFrame}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, StructType}
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * Params for accelerated failure time (AFT) regression.
+ */
+private[regression] trait AFTSurvivalRegressionParams extends Params
+ with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
+ with HasTol with HasFitIntercept {
+
+ /**
+ * Param for censor column name.
+ * The value of this column must be 0 or 1.
+ * If the value is 1, the event has occurred, i.e. the instance is uncensored; otherwise it is censored.
+ * @group param
+ */
+ @Since("1.6.0")
+ final val censorCol: Param[String] = new Param(this, "censorCol", "censor column name")
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getCensorCol: String = $(censorCol)
+ setDefault(censorCol -> "censor")
+
+ /**
+ * Param for quantile probabilities array.
+ * Values of the quantile probabilities array should be in the range [0, 1].
+ * @group param
+ */
+ @Since("1.6.0")
+ final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this,
+ "quantileProbabilities", "quantile probabilities array",
+ (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)))
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getQuantileProbabilities: Array[Double] = $(quantileProbabilities)
+
+ /** Checks whether a non-empty quantile probabilities array has been set. */
+ protected[regression] def hasQuantileProbabilities: Boolean = {
+ isDefined(quantileProbabilities) && $(quantileProbabilities).nonEmpty
+ }
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param fitting whether the schema is for fitting (training) or for prediction
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ if (fitting) {
+ SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
+ SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ }
+ SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Fits the accelerated failure time (AFT) model, a parametric survival regression model
+ * ([[https://en.wikipedia.org/wiki/Accelerated_failure_time_model]])
+ * based on the Weibull distribution of the survival time.
+ */
+@Experimental
+@Since("1.6.0")
+class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: String)
+ extends Estimator[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams with Logging {
+
+ @Since("1.6.0")
+ def this() = this(Identifiable.randomUID("aftSurvReg"))
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setLabelCol(value: String): this.type = set(labelCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setCensorCol(value: String): this.type = set(censorCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /**
+ * Set whether to fit an intercept term.
+ * Default is true.
+ * @group setParam
+ */
+ @Since("1.6.0")
+ def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+ setDefault(fitIntercept -> true)
+
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ * @group setParam
+ */
+ @Since("1.6.0")
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ setDefault(maxIter -> 100)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * A smaller value will lead to higher accuracy at the cost of more iterations.
+ * Default is 1E-6.
+ * @group setParam
+ */
+ @Since("1.6.0")
+ def setTol(value: Double): this.type = set(tol, value)
+ setDefault(tol -> 1E-6)
+
+ /**
+ * Extracts [[featuresCol]], [[labelCol]] and [[censorCol]] from the input dataset
+ * and puts them into an RDD of strongly typed [[AFTPoint]]s.
+ */
+ protected[ml] def extractAFTPoints(dataset: DataFrame): RDD[AFTPoint] = {
+ dataset.select($(featuresCol), $(labelCol), $(censorCol)).map {
+ case Row(features: Vector, label: Double, censor: Double) =>
+ AFTPoint(features, label, censor)
+ }
+ }
+
+ @Since("1.6.0")
+ override def fit(dataset: DataFrame): AFTSurvivalRegressionModel = {
+ validateAndTransformSchema(dataset.schema, fitting = true)
+ val instances = extractAFTPoints(dataset)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val costFun = new AFTCostFun(instances, $(fitIntercept))
+ val optimizer = new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
+
+ val numFeatures = dataset.select($(featuresCol)).take(1)(0).getAs[Vector](0).size
+ /*
+ The weights vector has three parts:
+ the first element: Double, log(sigma), the log of the scale parameter
+ the second element: Double, the intercept of the beta parameter
+ the remaining elements: Doubles, the regression coefficients vector of the beta parameter
+ */
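+ // Starting from all zeros corresponds to sigma = exp(0) = 1, zero intercept and zero coefficients.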
+ val initialWeights = Vectors.zeros(numFeatures + 2)
+
+ val states = optimizer.iterations(new CachedDiffFunction(costFun),
+ initialWeights.toBreeze.toDenseVector)
+
+ val weights = {
+ val arrayBuilder = mutable.ArrayBuilder.make[Double]
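+ // Accumulates the adjusted objective value (loss) per iteration; only the final state's weights are used below.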
+ var state: optimizer.State = null
+ while (states.hasNext) {
+ state = states.next()
+ arrayBuilder += state.adjustedValue
+ }
+ if (state == null) {
+ val msg = s"${optimizer.getClass.getName} failed."
+ throw new SparkException(msg)
+ }
+
+ state.x.toArray.clone()
+ }
+
+ if (handlePersistence) instances.unpersist()
+
+ val coefficients = Vectors.dense(weights.slice(2, weights.length))
+ val intercept = weights(1)
+ val scale = math.exp(weights(0))
+ val model = new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale)
+ copyValues(model.setParent(this))
+ }
+
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = true)
+ }
+
+ @Since("1.6.0")
+ override def copy(extra: ParamMap): AFTSurvivalRegression = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Model produced by [[AFTSurvivalRegression]].
+ */
+@Experimental
+@Since("1.6.0")
+class AFTSurvivalRegressionModel private[ml] (
+ @Since("1.6.0") override val uid: String,
+ @Since("1.6.0") val coefficients: Vector,
+ @Since("1.6.0") val intercept: Double,
+ @Since("1.6.0") val scale: Double)
+ extends Model[AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams {
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
+
+ @Since("1.6.0")
+ def predictQuantiles(features: Vector): Vector = {
+ require(hasQuantileProbabilities,
+ "quantileProbabilities must be set to a non-empty array before calling predictQuantiles.")
+ // scale parameter for the Weibull distribution of lifetime
+ val lambda = math.exp(BLAS.dot(coefficients, features) + intercept)
+ // shape parameter for the Weibull distribution of lifetime
+ val k = 1 / scale
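+ // Weibull quantile function: Q(p) = lambda * (-log(1 - p))^(1/k)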
+ val quantiles = $(quantileProbabilities).map {
+ q => lambda * math.exp(math.log(-math.log(1 - q)) / k)
+ }
+ Vectors.dense(quantiles)
+ }
+
+ @Since("1.6.0")
+ def predict(features: Vector): Double = {
+ math.exp(BLAS.dot(coefficients, features) + intercept)
+ }
+
+ @Since("1.6.0")
+ override def transform(dataset: DataFrame): DataFrame = {
+ transformSchema(dataset.schema)
+ val predictUDF = udf { features: Vector => predict(features) }
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
+ @Since("1.6.0")
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema, fitting = false)
+ }
+
+ @Since("1.6.0")
+ override def copy(extra: ParamMap): AFTSurvivalRegressionModel = {
+ copyValues(new AFTSurvivalRegressionModel(uid, coefficients, intercept, scale), extra)
+ .setParent(parent)
+ }
+}
+
+/**
+ * AFTAggregator computes the gradient and loss for an AFT loss function,
+ * as used in AFT survival regression, for samples in sparse or dense vector format
+ * in an online fashion.
+ *
+ * The loss function and likelihood function under the AFT model are based on:
+ * Lawless, J. F., Statistical Models and Methods for Lifetime Data,
+ * New York: John Wiley & Sons, Inc. 2003.
+ *
+ * Two AFTAggregators can be merged together to obtain a summary of the loss and gradient
+ * of the corresponding joint dataset.
+ *
+ * Given the values of the covariates x^{'}, for the random lifetimes t_{i} of subjects i = 1, ..., n,
+ * with possible right-censoring, the likelihood function under the AFT model is given as
+ * {{{
+ * L(\beta,\sigma)=\prod_{i=1}^n[\frac{1}{\sigma}f_{0}
+ * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})]^{\delta_{i}}S_{0}
+ * (\frac{\log{t_{i}}-x^{'}\beta}{\sigma})^{1-\delta_{i}}
+ * }}}
+ * where \delta_{i} is the indicator that the event has occurred, i.e. that subject i is uncensored.
+ * Using \epsilon_{i}=\frac{\log{t_{i}}-x^{'}\beta}{\sigma}, the log-likelihood function
+ * assumes the form
+ * {{{
+ * \iota(\beta,\sigma)=\sum_{i=1}^{n}[-\delta_{i}\log\sigma+
+ * \delta_{i}\log{f_{0}}(\epsilon_{i})+(1-\delta_{i})\log{S_{0}(\epsilon_{i})}]
+ * }}}
+ * where S_{0}(\epsilon_{i}) is the baseline survivor function,
+ * and f_{0}(\epsilon_{i}) is the corresponding density function.
+ *
+ * The most commonly used log-linear survival regression method is based on the Weibull
+ * distribution of the survival time. The Weibull distribution of the lifetime corresponds
+ * to the extreme value distribution of the log of the lifetime,
+ * and the S_{0}(\epsilon) function is
+ * {{{
+ * S_{0}(\epsilon_{i})=\exp(-e^{\epsilon_{i}})
+ * }}}
+ * and the f_{0}(\epsilon_{i}) function is
+ * {{{
+ * f_{0}(\epsilon_{i})=e^{\epsilon_{i}}\exp(-e^{\epsilon_{i}})
+ * }}}
+ * The log-likelihood function for Weibull distribution of lifetime is
+ * {{{
+ * \iota(\beta,\sigma)=
+ * -\sum_{i=1}^n[\delta_{i}\log\sigma-\delta_{i}\epsilon_{i}+e^{\epsilon_{i}}]
+ * }}}
+ * Since minimizing the negative log-likelihood is equivalent to maximizing the likelihood
+ * (maximum likelihood estimation), the loss function we use to optimize is -\iota(\beta,\sigma).
+ * The gradient functions for \beta and \log\sigma respectively are
+ * {{{
+ * \frac{\partial (-\iota)}{\partial \beta}=
+ * \sum_{i=1}^{n}[\delta_{i}-e^{\epsilon_{i}}]\frac{x_{i}}{\sigma}
+ * }}}
+ * {{{
+ * \frac{\partial (-\iota)}{\partial (\log\sigma)}=
+ * \sum_{i=1}^{n}[\delta_{i}+(\delta_{i}-e^{\epsilon_{i}})\epsilon_{i}]
+ * }}}
+ * @param weights The log of the scale parameter, followed by the intercept and the
+ * regression coefficients corresponding to the features.
+ * @param fitIntercept Whether to fit an intercept term.
+ */
+private class AFTAggregator(weights: BDV[Double], fitIntercept: Boolean)
+ extends Serializable {
+
+ // beta holds the intercept followed by the regression coefficients of the covariates
+ private val beta = weights.slice(1, weights.length)
+ // sigma is the scale parameter of the AFT model
+ private val sigma = math.exp(weights(0))
+
+ private var totalCnt: Long = 0L
+ private var lossSum = 0.0
+ private var gradientBetaSum = BDV.zeros[Double](beta.length)
+ private var gradientLogSigmaSum = 0.0
+
+ def count: Long = totalCnt
+
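+ // Average loss over the instances seen so far; defined as 1.0 for an empty aggregator.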
+ def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
+
+ // Gradient with respect to (log(sigma), intercept and coefficients), matching the weights layout
+ def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
+ gradientBetaSum / totalCnt.toDouble)
+
+ /**
+ * Add a new training instance to this AFTAggregator, and update the loss and gradient
+ * of the objective function.
+ *
+ * @param data The AFTPoint representation for one data point to be added into this aggregator.
+ * @return This AFTAggregator object.
+ */
+ def add(data: AFTPoint): this.type = {
+
+ // TODO: Don't create a new xi vector each time.
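+ // Prepending 1.0 lets beta(0) act as the intercept; with fitIntercept = false the leading 0.0
+ // zeroes the intercept term in beta.dot(xi).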
+ val xi = if (fitIntercept) {
+ Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze
+ } else {
+ Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze
+ }
+ val ti = data.label
+ val delta = data.censor
+ val epsilon = (math.log(ti) - beta.dot(xi)) / sigma
+
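+ // Per-instance negative log-likelihood: delta * log(sigma) - delta * epsilon + exp(epsilon)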
+ lossSum += math.log(sigma) * delta
+ lossSum += (math.exp(epsilon) - delta * epsilon)
+
+ // Sanity check (should never occur):
+ assert(!lossSum.isInfinity,
+ "AFTAggregator loss sum is infinite; the data or current weights caused numerical overflow.")
+
+ gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma
+ gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon
+
+ totalCnt += 1
+ this
+ }
+
+ /**
+ * Merge another AFTAggregator, and update the loss and gradient
+ * of the objective function.
+ * (Note that it's in place merging; as a result, `this` object will be modified.)
+ *
+ * @param other The other AFTAggregator to be merged.
+ * @return This AFTAggregator object.
+ */
+ def merge(other: AFTAggregator): this.type = {
+ // Guard on other.count rather than this.totalCnt so that merging a non-empty
+ // aggregator into an empty one does not silently drop the other's statistics.
+ if (other.count != 0) {
+ totalCnt += other.totalCnt
+ lossSum += other.lossSum
+
+ gradientBetaSum += other.gradientBetaSum
+ gradientLogSigmaSum += other.gradientLogSigmaSum
+ }
+ this
+ }
+}
+
+/**
+ * AFTCostFun implements Breeze's DiffFunction[T] for AFT cost.
+ * It returns the loss and gradient at a particular point (coefficients).
+ * It's used in Breeze's convex optimization routines.
+ */
+private class AFTCostFun(data: RDD[AFTPoint], fitIntercept: Boolean)
+ extends DiffFunction[BDV[Double]] {
+
+ override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
+
+ val aftAggregator = data.treeAggregate(new AFTAggregator(coefficients, fitIntercept))(
+ seqOp = (c, v) => (c, v) match {
+ case (aggregator, instance) => aggregator.add(instance)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+ })
+
+ (aftAggregator.loss, aftAggregator.gradient)
+ }
+}
+
+/**
+ * Class that represents the (features, label, censor) of a data point.
+ *
+ * @param features Feature vector for this data point.
+ * @param label Label for this data point.
+ * @param censor Indicator of whether the event has occurred. If the value is 1,
+ * the event has occurred, i.e. the point is uncensored; otherwise it is censored.
+ */
+private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) {
+ require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0")
+}
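As a sanity check on the loss and gradient formulas documented in AFTAggregator above, here is a standalone Scala sketch (editor's illustration, not part of the patch; `AFTGradientCheck` and all values are local to the sketch) that compares the analytic per-instance gradient against central finite differences:

```scala
object AFTGradientCheck extends App {
  // One observation: features xi (the leading 1.0 plays the intercept role),
  // lifetime t, and censoring indicator delta (1.0 = uncensored).
  val xi = Array(1.0, 0.5, -0.3)
  val t = 2.0
  val delta = 1.0

  // Per-instance negative log-likelihood as a function of w = (log(sigma), beta).
  def loss(w: Array[Double]): Double = {
    val sigma = math.exp(w(0))
    val beta = w.drop(1)
    val eps = (math.log(t) - beta.zip(xi).map { case (b, v) => b * v }.sum) / sigma
    delta * math.log(sigma) - delta * eps + math.exp(eps)
  }

  val w = Array(0.1, 0.2, -0.4, 0.3)
  val sigma = math.exp(w(0))
  val eps = (math.log(t) - w.drop(1).zip(xi).map { case (b, v) => b * v }.sum) / sigma

  // Analytic gradients, matching the formulas in the doc comment.
  val gradLogSigma = delta + (delta - math.exp(eps)) * eps
  val gradBeta = xi.map(v => (delta - math.exp(eps)) * v / sigma)

  // Central finite differences for comparison.
  val h = 1e-6
  val numGrad = w.indices.map { i =>
    val wp = w.clone(); wp(i) += h
    val wm = w.clone(); wm(i) -= h
    (loss(wp) - loss(wm)) / (2 * h)
  }
  println("analytic : " + (gradLogSigma +: gradBeta.toSeq).mkString(", "))
  println("numerical: " + numGrad.mkString(", "))
}
```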
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
new file mode 100644
index 0000000000..ca7140a45e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Row, DataFrame}
+
+class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ @transient var datasetUnivariate: DataFrame = _
+ @transient var datasetMultivariate: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ datasetUnivariate = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)))
+ datasetMultivariate = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new AFTSurvivalRegression)
+ val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("aft survival regression: default params") {
+ val aftr = new AFTSurvivalRegression
+ assert(aftr.getLabelCol === "label")
+ assert(aftr.getFeaturesCol === "features")
+ assert(aftr.getPredictionCol === "prediction")
+ assert(aftr.getCensorCol === "censor")
+ assert(aftr.getFitIntercept)
+ assert(aftr.getMaxIter === 100)
+ assert(aftr.getTol === 1E-6)
+ val model = aftr.fit(datasetUnivariate)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
+ model.transform(datasetUnivariate)
+ .select("label", "prediction")
+ .collect()
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ }
+
+ def generateAFTInput(
+ numFeatures: Int,
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ nPoints: Int,
+ seed: Int,
+ weibullShape: Double,
+ weibullScale: Double,
+ exponentialMean: Double): Seq[AFTPoint] = {
+
+ def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 }
+
+ val weibull = new WeibullGenerator(weibullShape, weibullScale)
+ weibull.setSeed(seed)
+
+ val exponential = new ExponentialGenerator(exponentialMean)
+ exponential.setSeed(seed)
+
+ val rnd = new Random(seed)
+ val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble()))
+
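+ // Shift and scale the uniform draws to the requested means and variances (Var(U(0,1)) = 1/12).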
+ x.foreach { v =>
+ var i = 0
+ val len = v.length
+ while (i < len) {
+ v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ i += 1
+ }
+ }
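+ // Pair each Weibull lifetime with an exponential censoring time;
+ // censor = 1 iff the lifetime is observed (lifetime <= censoring time).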
+ val y = (1 to nPoints).map { _ => (weibull.nextValue(), exponential.nextValue()) }
+
+ y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) }
+ }
+
+ test("aft survival regression with univariate") {
+ val trainer = new AFTSurvivalRegression
+ val model = trainer.fit(datasetUnivariate)
+
+ /*
+ The following R code loads the data and trains the model using the survival package.
+
+ library("survival")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ features <- data$V1
+ censor <- data$V2
+ label <- data$V3
+ sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull')
+ summary(sr.fit)
+
+ Value Std. Error z p
+ (Intercept) 1.759 0.4141 4.247 2.16e-05
+ features -0.039 0.0735 -0.531 5.96e-01
+ Log(scale) 0.344 0.0379 9.073 1.16e-19
+
+ Scale= 1.41
+
+ Weibull distribution
+ Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3
+ Chisq= 0.28 on 1 degrees of freedom, p= 0.6
+ Number of Newton-Raphson Iterations: 5
+ n= 1000
+ */
+ val coefficientsR = Vectors.dense(-0.039)
+ val interceptR = 1.759
+ val scaleR = 1.41
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.coefficients ~== coefficientsR relTol 1E-3)
+ assert(model.scale ~== scaleR relTol 1E-3)
+
+ /*
+ The following R code computes the predictions.
+
+ testdata <- list(features=6.559282795753792)
+ responsePredict <- predict(sr.fit, newdata=testdata)
+ responsePredict
+
+ 1
+ 4.494763
+
+ quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+ quantilePredict
+
+ [1] 0.1879174 2.6801195 14.5779394
+ */
+ val features = Vectors.dense(6.559282795753792)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val responsePredictR = 4.494763
+ val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394)
+
+ assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+ model.setQuantileProbabilities(quantileProbabilities)
+ assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+ model.transform(datasetUnivariate).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
+ test("aft survival regression with multivariate") {
+ val trainer = new AFTSurvivalRegression
+ val model = trainer.fit(datasetMultivariate)
+
+ /*
+ The following R code loads the data and trains the model using the survival package.
+
+ library("survival")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ feature1 <- data$V1
+ feature2 <- data$V2
+ censor <- data$V3
+ label <- data$V4
+ sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull')
+ summary(sr.fit)
+
+ Value Std. Error z p
+ (Intercept) 1.9206 0.1057 18.171 8.78e-74
+ feature1 -0.0844 0.0611 -1.381 1.67e-01
+ feature2 0.0677 0.0468 1.447 1.48e-01
+ Log(scale) -0.0236 0.0436 -0.542 5.88e-01
+
+ Scale= 0.977
+
+ Weibull distribution
+ Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7
+ Chisq= 3.91 on 2 degrees of freedom, p= 0.14
+ Number of Newton-Raphson Iterations: 5
+ n= 1000
+ */
+ val coefficientsR = Vectors.dense(-0.0844, 0.0677)
+ val interceptR = 1.9206
+ val scaleR = 0.977
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.coefficients ~== coefficientsR relTol 1E-3)
+ assert(model.scale ~== scaleR relTol 1E-3)
+
+ /*
+ The following R code computes the predictions.
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683)
+ responsePredict <- predict(sr.fit, newdata=testdata)
+ responsePredict
+
+ 1
+ 4.761219
+
+ quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+ quantilePredict
+
+ [1] 0.5287044 3.3285858 10.7517072
+ */
+ val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val responsePredictR = 4.761219
+ val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072)
+
+ assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+ model.setQuantileProbabilities(quantileProbabilities)
+ assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+ model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
+ test("aft survival regression w/o intercept") {
+ val trainer = new AFTSurvivalRegression().setFitIntercept(false)
+ val model = trainer.fit(datasetMultivariate)
+
+ /*
+ The following R code loads the data and trains the model using the survival package.
+
+ library("survival")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ feature1 <- data$V1
+ feature2 <- data$V2
+ censor <- data$V3
+ label <- data$V4
+ sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull')
+ summary(sr.fit)
+
+ Value Std. Error z p
+ feature1 0.896 0.0685 13.1 3.93e-39
+ feature2 -0.709 0.0522 -13.6 5.78e-42
+ Log(scale) 0.420 0.0401 10.5 1.23e-25
+
+ Scale= 1.52
+
+ Weibull distribution
+ Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7
+ Chisq= -439.57 on 1 degrees of freedom, p= 1
+ Number of Newton-Raphson Iterations: 6
+ n= 1000
+ */
+ val coefficientsR = Vectors.dense(0.896, -0.709)
+ val interceptR = 0.0
+ val scaleR = 1.52
+
+ assert(model.intercept === interceptR)
+ assert(model.coefficients ~== coefficientsR relTol 1E-3)
+ assert(model.scale ~== scaleR relTol 1E-3)
+
+ /*
+ The following R code computes the predictions.
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683)
+ responsePredict <- predict(sr.fit, newdata=testdata)
+ responsePredict
+
+ 1
+ 44.54465
+
+ quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+ quantilePredict
+
+ [1] 1.452103 25.506077 158.428600
+ */
+ val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val responsePredictR = 44.54465
+ val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600)
+
+ assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+ model.setQuantileProbabilities(quantileProbabilities)
+ assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+ model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+}
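For reference, the quantile values R reports in the tests above follow directly from the Weibull quantile function that predictQuantiles implements. A standalone sketch (editor's illustration, not part of the patch) reproduces the univariate test's quantiles from the rounded parameters R prints, so agreement is only to about three decimal places:

```scala
object WeibullQuantileCheck extends App {
  // Rounded parameters from the univariate R fit above.
  val intercept = 1.759
  val coefficient = -0.039
  val scale = 1.41
  val feature = 6.559282795753792

  val lambda = math.exp(coefficient * feature + intercept) // Weibull scale of the lifetime
  val k = 1.0 / scale                                      // Weibull shape
  // Q(p) = lambda * (-log(1 - p))^(1/k)
  val quantiles = Array(0.1, 0.5, 0.9).map(p => lambda * math.pow(-math.log(1 - p), 1.0 / k))
  println(quantiles.mkString(", ")) // ~0.188, 2.681, 14.57 vs. R's 0.1879, 2.6801, 14.5779
}
```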