diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-09-17 21:37:10 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-17 21:37:10 -0700 |
commit | 98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b (patch) | |
tree | 667348315237c6a2543080958789341d75b679f7 /mllib/src/test/scala/org/apache | |
parent | 0f5ef6dfa67a068606aff8ea9d1addfce73446eb (diff) | |
download | spark-98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b.tar.gz spark-98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b.tar.bz2 spark-98f1ea67da1b0e3aa791c3cbfa06e48e2ba0d75b.zip |
[SPARK-8518] [ML] Log-linear models for survival analysis
[Accelerated Failure Time (AFT) model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) is the most commonly used and easy to parallel method of survival analysis for censored survival data. It is the log-linear model based on the Weibull distribution of the survival time.
Users can refer to the R function [```survreg```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/survreg.html) to compare the model and [```predict```](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/predict.survreg.html) to compare the prediction. There are different kinds of model prediction, I have just select the type ```response``` which is default used for R.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #8611 from yanboliang/spark-8518.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 311 |
1 files changed, 311 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala new file mode 100644 index 0000000000..ca7140a45e --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -0,0 +1,311 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.regression + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.{DenseVector, Vectors} +import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} +import org.apache.spark.mllib.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Row, DataFrame} + +class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext { + + @transient var datasetUnivariate: DataFrame = _ + @transient var datasetMultivariate: DataFrame = _ + + override def beforeAll(): Unit = { + super.beforeAll() + datasetUnivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0))) + datasetMultivariate = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + } + + test("params") { + ParamsSuite.checkParams(new AFTSurvivalRegression) + val model = new AFTSurvivalRegressionModel("aftSurvReg", Vectors.dense(0.0), 0.0, 0.0) + ParamsSuite.checkParams(model) + } + + test("aft survival regression: default params") { + val aftr = new AFTSurvivalRegression + assert(aftr.getLabelCol === "label") + assert(aftr.getFeaturesCol === "features") + assert(aftr.getPredictionCol === "prediction") + assert(aftr.getCensorCol === "censor") + assert(aftr.getFitIntercept) + assert(aftr.getMaxIter === 100) + assert(aftr.getTol === 1E-6) + val model = aftr.fit(datasetUnivariate) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + + model.transform(datasetUnivariate) + .select("label", "prediction") + .collect() + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + } + + def generateAFTInput( + numFeatures: Int, + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + weibullShape: Double, + weibullScale: Double, + exponentialMean: Double): Seq[AFTPoint] = { + + def censor(x: Double, y: Double): Double = { if (x <= y) 1.0 else 0.0 } + + val weibull = new WeibullGenerator(weibullShape, weibullScale) + weibull.setSeed(seed) + + val exponential = new ExponentialGenerator(exponentialMean) + exponential.setSeed(seed) + + val rnd = new Random(seed) + val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](numFeatures)(rnd.nextDouble())) + + x.foreach { v => + var i = 0 + val len = v.length + while (i < len) { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + i += 1 + } + } + val y = (1 to nPoints).map { i => (weibull.nextValue(), exponential.nextValue()) } + + y.zip(x).map { p => AFTPoint(Vectors.dense(p._2), p._1._1, censor(p._1._1, p._1._2)) } + } + + test("aft survival regression with univariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + features <- data$V1 + censor <- data$V2 + label <- data$V3 + sr.fit <- survreg(Surv(label, censor) ~ features, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.759 0.4141 4.247 2.16e-05 + features -0.039 0.0735 -0.531 5.96e-01 + Log(scale) 0.344 0.0379 9.073 1.16e-19 + + Scale= 1.41 + + Weibull distribution + Loglik(model)= -1152.2 Loglik(intercept only)= -1152.3 + Chisq= 0.28 on 1 degrees of freedom, p= 0.6 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.039) + val interceptR = 1.759 + val scaleR = 1.41 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + + testdata <- list(features=6.559282795753792) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.494763 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.1879174 2.6801195 14.5779394 + */ + val features = Vectors.dense(6.559282795753792) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.494763 + val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression with multivariate") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + (Intercept) 1.9206 0.1057 18.171 8.78e-74 + feature1 -0.0844 0.0611 -1.381 1.67e-01 + feature2 0.0677 0.0468 1.447 1.48e-01 + Log(scale) -0.0236 0.0436 -0.542 5.88e-01 + + Scale= 0.977 + + Weibull distribution + Loglik(model)= -1070.7 Loglik(intercept only)= -1072.7 + Chisq= 3.91 on 2 degrees of freedom, p= 0.14 + Number of Newton-Raphson Iterations: 5 + n= 1000 + */ + val coefficientsR = Vectors.dense(-0.0844, 0.0677) + val interceptR = 1.9206 + val scaleR = 0.977 + + assert(model.intercept ~== interceptR relTol 1E-3) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 4.761219 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 0.5287044 3.3285858 10.7517072 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 4.761219 + val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } + + test("aft survival regression w/o intercept") { + val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val model = trainer.fit(datasetMultivariate) + + /* + Using the following R code to load the data and train the model using survival package. + + library("survival") + data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE) + feature1 <- data$V1 + feature2 <- data$V2 + censor <- data$V3 + label <- data$V4 + sr.fit <- survreg(Surv(label, censor) ~ feature1 + feature2 - 1, dist='weibull') + summary(sr.fit) + + Value Std. Error z p + feature1 0.896 0.0685 13.1 3.93e-39 + feature2 -0.709 0.0522 -13.6 5.78e-42 + Log(scale) 0.420 0.0401 10.5 1.23e-25 + + Scale= 1.52 + + Weibull distribution + Loglik(model)= -1292.4 Loglik(intercept only)= -1072.7 + Chisq= -439.57 on 1 degrees of freedom, p= 1 + Number of Newton-Raphson Iterations: 6 + n= 1000 + */ + val coefficientsR = Vectors.dense(0.896, -0.709) + val interceptR = 0.0 + val scaleR = 1.52 + + assert(model.intercept === interceptR) + assert(model.coefficients ~== coefficientsR relTol 1E-3) + assert(model.scale ~== scaleR relTol 1E-3) + + /* + Using the following R code to predict. + testdata <- list(feature1=2.233396950271428, feature2=-2.5321374085997683) + responsePredict <- predict(sr.fit, newdata=testdata) + responsePredict + + 1 + 44.54465 + + quantilePredict <- predict(sr.fit, newdata=testdata, type='quantile', p=c(0.1, 0.5, 0.9)) + quantilePredict + + [1] 1.452103 25.506077 158.428600 + */ + val features = Vectors.dense(2.233396950271428, -2.5321374085997683) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val responsePredictR = 44.54465 + val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) + + assert(model.predict(features) ~== responsePredictR relTol 1E-3) + model.setQuantileProbabilities(quantileProbabilities) + assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) + + model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { + case Row(features: DenseVector, prediction1: Double) => + val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) + assert(prediction1 ~== prediction2 relTol 1E-5) + } + } +} |