aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-17 21:37:10 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-17 21:37:10 -0700
commit98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b (patch)
tree667348315237c6a2543080958789341d75b679f7 /mllib/src/test/scala/org
parent0f5ef6dfa67a068606aff8ea9d1addfce73446eb (diff)
downloadspark-98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b.tar.gz
spark-98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b.tar.bz2
spark-98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b.zip
[SPARK-8518] [ML] Log-linear models for survival analysis
[Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time. Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8611 from yanboliang/spark-8518.
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala311
1 files changed, 311 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
new file mode 100644
index 0000000000..ca7140a45e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -0,0 +1,311 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{Row, DataFrame}
+
+class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ @transient var datasetUnivariate: DataFrame = _
+ @transient var datasetMultivariate: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ datasetUnivariate = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)))
+ datasetMultivariate = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+ }
+
+ test("params") {
+ ParamsSuite.checkParams(new AFTSurvivalRegression)
+ val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0)
+ ParamsSuite.checkParams(model)
+ }
+
+ test("aft survival regression: default params") {
+ val aftr = new AFTSurvivalRegression
+ assert(aftr.getLabelCol === "label")
+ assert(aftr.getFeaturesCol === "features")
+ assert(aftr.getPredictionCol === "prediction")
+ assert(aftr.getCensorCol === "censor")
+ assert(aftr.getFitIntercept)
+ assert(aftr.getMaxIter === 100)
+ assert(aftr.getTol === 1E-6)
+ val model = aftr.fit(datasetUnivariate)
+
+ // copied model must have the same parent.
+ MLTestingUtils.checkCopy(model)
+
+ model.transform(datasetUnivariate)
+ .select("label", "prediction")
+ .collect()
+ assert(model.getFeaturesCol === "features")
+ assert(model.getPredictionCol === "prediction")
+ assert(model.intercept !== 0.0)
+ assert(model.hasParent)
+ }
+
+ def generateAFTInput(
+ numFeatures: Int,
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ nPoints: Int,
+ seed: Int,
+ weibullShape: Double,
+ weibullScale: Double,
+ exponentialMean: Double): Seq[AFTPoint] = {
+
+ def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 }
+
+ val weibull = new WeibullGenerator(weibullShape, weibullScale)
+ weibull.setSeed(seed)
+
+ val exponential = new ExponentialGenerator(exponentialMean)
+ exponential.setSeed(seed)
+
+ val rnd = new Random(seed)
+ val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble()))
+
+ x.foreach { v =>
+ var i = 0
+ val len = v.length
+ while (i < len) {
+ v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ i += 1
+ }
+ }
+ val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) }
+
+ y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) }
+ }
+
+ test("aft survival regression with univariate") {
+ val trainer = new AFTSurvivalRegression
+ val model = trainer.fit(datasetUnivariate)
+
+ /*
+ Using the following R code to load the data and train the model using survival package.
+
+ library("survival")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ features <- data$V1
+ censor <- data$V2
+ label <- data$V3
+ sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull')
+ summary(sr.fit)
+
+ Value Std. Error z p
+ (Intercept) 1.759 0.4141 4.247 2.16e-05
+ features -0.039 0.0735 -0.531 5.96e-01
+ Log(scale) 0.344 0.0379 9.073 1.16e-19
+
+ Scale= 1.41
+
+ Weibull distribution
+ Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3
+ Chisq= 0.28 on 1 degrees of freedom, p= 0.6
+ Number of Newton-Raphson Iterations: 5
+ n= 1000
+ */
+ val coefficientsR = Vectors.dense(-0.039)
+ val interceptR = 1.759
+ val scaleR = 1.41
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.coefficients ~== coefficientsR relTol 1E-3)
+ assert(model.scale ~== scaleR relTol 1E-3)
+
+ /*
+ Using the following R code to predict.
+
+ testdata <- list(features=6.559282795753792)
+ responsePredict <- predict(sr.fit, newdata=testdata)
+ responsePredict
+
+ 1
+ 4.494763
+
+ quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+ quantilePredict
+
+ [1] 0.1879174 2.6801195 14.5779394
+ */
+ val features = Vectors.dense(6.559282795753792)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val responsePredictR = 4.494763
+ val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394)
+
+ assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+ model.setQuantileProbabilities(quantileProbabilities)
+ assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+ model.transform(datasetUnivariate).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
+ test("aft survival regression with multivariate") {
+ val trainer = new AFTSurvivalRegression
+ val model = trainer.fit(datasetMultivariate)
+
+ /*
+ Using the following R code to load the data and train the model using survival package.
+
+ library("survival")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ feature1 <- data$V1
+ feature2 <- data$V2
+ censor <- data$V3
+ label <- data$V4
+ sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull')
+ summary(sr.fit)
+
+ Value Std. Error z p
+ (Intercept) 1.9206 0.1057 18.171 8.78e-74
+ feature1 -0.0844 0.0611 -1.381 1.67e-01
+ feature2 0.0677 0.0468 1.447 1.48e-01
+ Log(scale) -0.0236 0.0436 -0.542 5.88e-01
+
+ Scale= 0.977
+
+ Weibull distribution
+ Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7
+ Chisq= 3.91 on 2 degrees of freedom, p= 0.14
+ Number of Newton-Raphson Iterations: 5
+ n= 1000
+ */
+ val coefficientsR = Vectors.dense(-0.0844, 0.0677)
+ val interceptR = 1.9206
+ val scaleR = 0.977
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.coefficients ~== coefficientsR relTol 1E-3)
+ assert(model.scale ~== scaleR relTol 1E-3)
+
+ /*
+ Using the following R code to predict.
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683)
+ responsePredict <- predict(sr.fit, newdata=testdata)
+ responsePredict
+
+ 1
+ 4.761219
+
+ quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+ quantilePredict
+
+ [1] 0.5287044 3.3285858 10.7517072
+ */
+ val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val responsePredictR = 4.761219
+ val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072)
+
+ assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+ model.setQuantileProbabilities(quantileProbabilities)
+ assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+ model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
+ test("aft survival regression w/o intercept") {
+ val trainer = new AFTSurvivalRegression().setFitIntercept(false)
+ val model = trainer.fit(datasetMultivariate)
+
+ /*
+ Using the following R code to load the data and train the model using survival package.
+
+ library("survival")
+ data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ feature1 <- data$V1
+ feature2 <- data$V2
+ censor <- data$V3
+ label <- data$V4
+ sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull')
+ summary(sr.fit)
+
+ Value Std. Error z p
+ feature1 0.896 0.0685 13.1 3.93e-39
+ feature2 -0.709 0.0522 -13.6 5.78e-42
+ Log(scale) 0.420 0.0401 10.5 1.23e-25
+
+ Scale= 1.52
+
+ Weibull distribution
+ Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7
+ Chisq= -439.57 on 1 degrees of freedom, p= 1
+ Number of Newton-Raphson Iterations: 6
+ n= 1000
+ */
+ val coefficientsR = Vectors.dense(0.896, -0.709)
+ val interceptR = 0.0
+ val scaleR = 1.52
+
+ assert(model.intercept === interceptR)
+ assert(model.coefficients ~== coefficientsR relTol 1E-3)
+ assert(model.scale ~== scaleR relTol 1E-3)
+
+ /*
+ Using the following R code to predict.
+ testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683)
+ responsePredict <- predict(sr.fit, newdata=testdata)
+ responsePredict
+
+ 1
+ 44.54465
+
+ quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9))
+ quantilePredict
+
+ [1] 1.452103 25.506077 158.428600
+ */
+ val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val responsePredictR = 44.54465
+ val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600)
+
+ assert(model.predict(features) ~== responsePredictR relTol 1E-3)
+ model.setQuantileProbabilities(quantileProbabilities)
+ assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
+
+ model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+}