diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-09-23 15:26:02 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-23 15:26:02 -0700 |
commit | ce2b056d35c0c75d5c162b93680ee2d84152e911 (patch) | |
tree | 66151913c8efefea478dfb6f8c17067b56e1b673 | |
parent | 098be27ad53c485ee2fc7f5871c47f899020e87b (diff) | |
download | spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.tar.gz spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.tar.bz2 spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.zip |
[SPARK-10686] [ML] Add quantilesCol to AFTSurvivalRegression
By default ```quantilesCol``` should be empty. If ```quantilesCol``` is set, we should append the quantiles corresponding to ```quantileProbabilities``` as a new column (of type Vector).
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #8836 from yanboliang/spark-10686.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 51 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 74 |
2 files changed, 91 insertions, 34 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 5b25db651f..717caacad3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel */ private[regression] trait AFTSurvivalRegressionParams extends Params with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter - with HasTol with HasFitIntercept { + with HasTol with HasFitIntercept with Logging { /** * Param for censor column name. @@ -59,21 +59,35 @@ private[regression] trait AFTSurvivalRegressionParams extends Params /** * Param for quantile probabilities array. - * Values of the quantile probabilities array should be in the range [0, 1]. + * Values of the quantile probabilities array should be in the range [0, 1] + * and the array should be non-empty. * @group param */ @Since("1.6.0") final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this, "quantileProbabilities", "quantile probabilities array", - (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1))) + (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)) && t.length > 0) /** @group getParam */ @Since("1.6.0") def getQuantileProbabilities: Array[Double] = $(quantileProbabilities) + setDefault(quantileProbabilities -> Array(0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99)) - /** Checks whether the input has quantile probabilities array. */ - protected[regression] def hasQuantileProbabilities: Boolean = { - isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0 + /** + * Param for quantiles column name. + * This column will output quantiles of corresponding quantileProbabilities if it is set. 
+ * @group param + */ + @Since("1.6.0") + final val quantilesCol: Param[String] = new Param(this, "quantilesCol", "quantiles column name") + + /** @group getParam */ + @Since("1.6.0") + def getQuantilesCol: String = $(quantilesCol) + + /** Checks whether the input has quantiles column name. */ + protected[regression] def hasQuantilesCol: Boolean = { + isDefined(quantilesCol) && $(quantilesCol) != "" } /** @@ -90,6 +104,9 @@ private[regression] trait AFTSurvivalRegressionParams extends Params SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) } + if (hasQuantilesCol) { + SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT) + } SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType) } } @@ -124,6 +141,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S @Since("1.6.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) + /** @group setParam */ + @Since("1.6.0") + def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + + /** @group setParam */ + @Since("1.6.0") + def setQuantilesCol(value: String): this.type = set(quantilesCol, value) + /** * Set if we should fit the intercept * Default is true. 
@@ -243,10 +268,12 @@ class AFTSurvivalRegressionModel private[ml] ( @Since("1.6.0") def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value) + /** @group setParam */ + @Since("1.6.0") + def setQuantilesCol(value: String): this.type = set(quantilesCol, value) + @Since("1.6.0") def predictQuantiles(features: Vector): Vector = { - require(hasQuantileProbabilities, - "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array") // scale parameter for the Weibull distribution of lifetime val lambda = math.exp(BLAS.dot(coefficients, features) + intercept) // shape parameter for the Weibull distribution of lifetime @@ -266,7 +293,13 @@ class AFTSurvivalRegressionModel private[ml] ( override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema) val predictUDF = udf { features: Vector => predict(features) } - dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)} + if (hasQuantilesCol) { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol)))) + } else { + dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } } @Since("1.6.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index ca7140a45e..359f310271 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -22,8 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vectors} -import 
org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -59,16 +58,20 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(aftr.getFitIntercept) assert(aftr.getMaxIter === 100) assert(aftr.getTol === 1E-6) - val model = aftr.fit(datasetUnivariate) + val model = aftr.setQuantileProbabilities(Array(0.1, 0.8)) + .setQuantilesCol("quantiles") + .fit(datasetUnivariate) // copied model must have the same parent. MLTestingUtils.checkCopy(model) model.transform(datasetUnivariate) - .select("label", "prediction") + .select("label", "prediction", "quantiles") .collect() assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") + assert(model.getQuantileProbabilities === Array(0.1, 0.8)) + assert(model.getQuantilesCol === "quantiles") assert(model.intercept !== 0.0) assert(model.hasParent) } @@ -108,7 +111,10 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex } test("aft survival regression with univariate") { - val trainer = new AFTSurvivalRegression + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") val model = trainer.fit(datasetUnivariate) /* @@ -159,23 +165,25 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 0.1879174 2.6801195 14.5779394 */ val features = Vectors.dense(6.559282795753792) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 4.494763 val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) 
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetUnivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) } } test("aft survival regression with multivariate") { - val trainer = new AFTSurvivalRegression + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") val model = trainer.fit(datasetMultivariate) /* @@ -227,23 +235,26 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 0.5287044 3.3285858 10.7517072 */ val features = Vectors.dense(2.233396950271428, -2.5321374085997683) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 4.761219 val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, 
quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) } } test("aft survival regression w/o intercept") { - val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + .setFitIntercept(false) val model = trainer.fit(datasetMultivariate) /* @@ -294,18 +305,31 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 1.452103 25.506077 158.428600 */ val features = Vectors.dense(2.233396950271428, -2.5321374085997683) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 44.54465 val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) + } + } + + test("aft survival regression w/o quantiles column") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + val outputDf = model.transform(datasetUnivariate) + + assert(outputDf.schema.fieldNames.contains("quantiles") === false) + + outputDf.select("features", "prediction") 
+ .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction ~== model.predict(features) relTol 1E-5) } } } |