aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-09-23 15:26:02 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-09-23 15:26:02 -0700
commit ce2b056d35c0c75d5c162b93680ee2d84152e911 (patch)
tree 66151913c8efefea478dfb6f8c17067b56e1b673 /mllib
parent 098be27ad53c485ee2fc7f5871c47f899020e87b (diff)
download spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.tar.gz
spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.tar.bz2
spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.zip
[SPARK-10686] [ML] Add quantilesCol to AFTSurvivalRegression
By default `quantilesCol` should be empty. If `quantilesCol` is set, we should append the quantiles for the configured `quantileProbabilities` as a new column (of type Vector). Author: Yanbo Liang <ybliang8@gmail.com>. Closes #8836 from yanboliang/spark-10686.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 51
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 74
2 files changed, 91 insertions, 34 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 5b25db651f..717caacad3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait AFTSurvivalRegressionParams extends Params
with HasFeaturesCol with HasLabelCol with HasPredictionCol with HasMaxIter
- with HasTol with HasFitIntercept {
+ with HasTol with HasFitIntercept with Logging {
/**
* Param for censor column name.
@@ -59,21 +59,35 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
/**
* Param for quantile probabilities array.
- * Values of the quantile probabilities array should be in the range [0, 1].
+ * Values of the quantile probabilities array should be in the range [0, 1]
+ * and the array should be non-empty.
* @group param
*/
@Since("1.6.0")
final val quantileProbabilities: DoubleArrayParam = new DoubleArrayParam(this,
"quantileProbabilities", "quantile probabilities array",
- (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)))
+ (t: Array[Double]) => t.forall(ParamValidators.inRange(0, 1)) && t.length > 0)
/** @group getParam */
@Since("1.6.0")
def getQuantileProbabilities: Array[Double] = $(quantileProbabilities)
+ setDefault(quantileProbabilities -> Array(0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99))
- /** Checks whether the input has quantile probabilities array. */
- protected[regression] def hasQuantileProbabilities: Boolean = {
- isDefined(quantileProbabilities) && $(quantileProbabilities).size != 0
+ /**
+ * Param for quantiles column name.
+ * This column will output quantiles of corresponding quantileProbabilities if it is set.
+ * @group param
+ */
+ @Since("1.6.0")
+ final val quantilesCol: Param[String] = new Param(this, "quantilesCol", "quantiles column name")
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getQuantilesCol: String = $(quantilesCol)
+
+ /** Checks whether the input has quantiles column name. */
+ protected[regression] def hasQuantilesCol: Boolean = {
+ isDefined(quantilesCol) && $(quantilesCol) != ""
}
/**
@@ -90,6 +104,9 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
}
+ if (hasQuantilesCol) {
+ SchemaUtils.appendColumn(schema, $(quantilesCol), new VectorUDT)
+ }
SchemaUtils.appendColumn(schema, $(predictionCol), DoubleType)
}
}
@@ -124,6 +141,14 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S
@Since("1.6.0")
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ /** @group setParam */
+ @Since("1.6.0")
+ def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
+
+ /** @group setParam */
+ @Since("1.6.0")
+ def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
+
/**
* Set if we should fit the intercept
* Default is true.
@@ -243,10 +268,12 @@ class AFTSurvivalRegressionModel private[ml] (
@Since("1.6.0")
def setQuantileProbabilities(value: Array[Double]): this.type = set(quantileProbabilities, value)
+ /** @group setParam */
+ @Since("1.6.0")
+ def setQuantilesCol(value: String): this.type = set(quantilesCol, value)
+
@Since("1.6.0")
def predictQuantiles(features: Vector): Vector = {
- require(hasQuantileProbabilities,
- "AFTSurvivalRegressionModel predictQuantiles must set quantile probabilities array")
// scale parameter for the Weibull distribution of lifetime
val lambda = math.exp(BLAS.dot(coefficients, features) + intercept)
// shape parameter for the Weibull distribution of lifetime
@@ -266,7 +293,13 @@ class AFTSurvivalRegressionModel private[ml] (
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema)
val predictUDF = udf { features: Vector => predict(features) }
- dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ val predictQuantilesUDF = udf { features: Vector => predictQuantiles(features)}
+ if (hasQuantilesCol) {
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ .withColumn($(quantilesCol), predictQuantilesUDF(col($(featuresCol))))
+ } else {
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
}
@Since("1.6.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index ca7140a45e..359f310271 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -22,8 +22,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -59,16 +58,20 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
assert(aftr.getFitIntercept)
assert(aftr.getMaxIter === 100)
assert(aftr.getTol === 1E-6)
- val model = aftr.fit(datasetUnivariate)
+ val model = aftr.setQuantileProbabilities(Array(0.1, 0.8))
+ .setQuantilesCol("quantiles")
+ .fit(datasetUnivariate)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
model.transform(datasetUnivariate)
- .select("label", "prediction")
+ .select("label", "prediction", "quantiles")
.collect()
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
+ assert(model.getQuantileProbabilities === Array(0.1, 0.8))
+ assert(model.getQuantilesCol === "quantiles")
assert(model.intercept !== 0.0)
assert(model.hasParent)
}
@@ -108,7 +111,10 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
}
test("aft survival regression with univariate") {
- val trainer = new AFTSurvivalRegression
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val trainer = new AFTSurvivalRegression()
+ .setQuantileProbabilities(quantileProbabilities)
+ .setQuantilesCol("quantiles")
val model = trainer.fit(datasetUnivariate)
/*
@@ -159,23 +165,25 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
[1] 0.1879174 2.6801195 14.5779394
*/
val features = Vectors.dense(6.559282795753792)
- val quantileProbabilities = Array(0.1, 0.5, 0.9)
val responsePredictR = 4.494763
val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394)
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
- model.setQuantileProbabilities(quantileProbabilities)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetUnivariate).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
+ assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
}
}
test("aft survival regression with multivariate") {
- val trainer = new AFTSurvivalRegression
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val trainer = new AFTSurvivalRegression()
+ .setQuantileProbabilities(quantileProbabilities)
+ .setQuantilesCol("quantiles")
val model = trainer.fit(datasetMultivariate)
/*
@@ -227,23 +235,26 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
[1] 0.5287044 3.3285858 10.7517072
*/
val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
- val quantileProbabilities = Array(0.1, 0.5, 0.9)
val responsePredictR = 4.761219
val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072)
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
- model.setQuantileProbabilities(quantileProbabilities)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
+ assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
}
}
test("aft survival regression w/o intercept") {
- val trainer = new AFTSurvivalRegression().setFitIntercept(false)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val trainer = new AFTSurvivalRegression()
+ .setQuantileProbabilities(quantileProbabilities)
+ .setQuantilesCol("quantiles")
+ .setFitIntercept(false)
val model = trainer.fit(datasetMultivariate)
/*
@@ -294,18 +305,31 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
[1] 1.452103 25.506077 158.428600
*/
val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
- val quantileProbabilities = Array(0.1, 0.5, 0.9)
val responsePredictR = 44.54465
val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600)
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
- model.setQuantileProbabilities(quantileProbabilities)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
+ assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
+ }
+ }
+
+ test("aft survival regression w/o quantiles column") {
+ val trainer = new AFTSurvivalRegression
+ val model = trainer.fit(datasetUnivariate)
+ val outputDf = model.transform(datasetUnivariate)
+
+ assert(outputDf.schema.fieldNames.contains("quantiles") === false)
+
+ outputDf.select("features", "prediction")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
}
}
}