aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-09-23 15:26:02 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-09-23 15:26:02 -0700
commit ce2b056d35c0c75d5c162b93680ee2d84152e911 (patch)
tree 66151913c8efefea478dfb6f8c17067b56e1b673 /mllib/src/test/scala/org/apache
parent 098be27ad53c485ee2fc7f5871c47f899020e87b (diff)
download spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.tar.gz
spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.tar.bz2
spark-ce2b056d35c0c75d5c162b93680ee2d84152e911.zip
[SPARK-10686] [ML] Add quantilesCol to AFTSurvivalRegression
By default ```quantilesCol``` should be empty. If ```quantileProbabilities``` is set, we should append quantiles as a new column (of type Vector). Author: Yanbo Liang <ybliang8@gmail.com> Closes #8836 from yanboliang/spark-10686.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala 74
1 files changed, 49 insertions, 25 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index ca7140a45e..359f310271 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -22,8 +22,7 @@ import scala.util.Random
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
-import org.apache.spark.mllib.linalg.{DenseVector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator}
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -59,16 +58,20 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
assert(aftr.getFitIntercept)
assert(aftr.getMaxIter === 100)
assert(aftr.getTol === 1E-6)
- val model = aftr.fit(datasetUnivariate)
+ val model = aftr.setQuantileProbabilities(Array(0.1, 0.8))
+ .setQuantilesCol("quantiles")
+ .fit(datasetUnivariate)
// copied model must have the same parent.
MLTestingUtils.checkCopy(model)
model.transform(datasetUnivariate)
- .select("label", "prediction")
+ .select("label", "prediction", "quantiles")
.collect()
assert(model.getFeaturesCol === "features")
assert(model.getPredictionCol === "prediction")
+ assert(model.getQuantileProbabilities === Array(0.1, 0.8))
+ assert(model.getQuantilesCol === "quantiles")
assert(model.intercept !== 0.0)
assert(model.hasParent)
}
@@ -108,7 +111,10 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
}
test("aft survival regression with univariate") {
- val trainer = new AFTSurvivalRegression
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val trainer = new AFTSurvivalRegression()
+ .setQuantileProbabilities(quantileProbabilities)
+ .setQuantilesCol("quantiles")
val model = trainer.fit(datasetUnivariate)
/*
@@ -159,23 +165,25 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
[1] 0.1879174 2.6801195 14.5779394
*/
val features = Vectors.dense(6.559282795753792)
- val quantileProbabilities = Array(0.1, 0.5, 0.9)
val responsePredictR = 4.494763
val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394)
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
- model.setQuantileProbabilities(quantileProbabilities)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetUnivariate).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ model.transform(datasetUnivariate).select("features", "prediction", "quantiles")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
+ assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
}
}
test("aft survival regression with multivariate") {
- val trainer = new AFTSurvivalRegression
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val trainer = new AFTSurvivalRegression()
+ .setQuantileProbabilities(quantileProbabilities)
+ .setQuantilesCol("quantiles")
val model = trainer.fit(datasetMultivariate)
/*
@@ -227,23 +235,26 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
[1] 0.5287044 3.3285858 10.7517072
*/
val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
- val quantileProbabilities = Array(0.1, 0.5, 0.9)
val responsePredictR = 4.761219
val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072)
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
- model.setQuantileProbabilities(quantileProbabilities)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
+ assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
}
}
test("aft survival regression w/o intercept") {
- val trainer = new AFTSurvivalRegression().setFitIntercept(false)
+ val quantileProbabilities = Array(0.1, 0.5, 0.9)
+ val trainer = new AFTSurvivalRegression()
+ .setQuantileProbabilities(quantileProbabilities)
+ .setQuantilesCol("quantiles")
+ .setFitIntercept(false)
val model = trainer.fit(datasetMultivariate)
/*
@@ -294,18 +305,31 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex
[1] 1.452103 25.506077 158.428600
*/
val features = Vectors.dense(2.233396950271428, -2.5321374085997683)
- val quantileProbabilities = Array(0.1, 0.5, 0.9)
val responsePredictR = 44.54465
val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600)
assert(model.predict(features) ~== responsePredictR relTol 1E-3)
- model.setQuantileProbabilities(quantileProbabilities)
assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3)
- model.transform(datasetMultivariate).select("features", "prediction").collect().foreach {
- case Row(features: DenseVector, prediction1: Double) =>
- val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept)
- assert(prediction1 ~== prediction2 relTol 1E-5)
+ model.transform(datasetMultivariate).select("features", "prediction", "quantiles")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double, quantiles: Vector) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
+ assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5)
+ }
+ }
+
+ test("aft survival regression w/o quantiles column") {
+ val trainer = new AFTSurvivalRegression
+ val model = trainer.fit(datasetUnivariate)
+ val outputDf = model.transform(datasetUnivariate)
+
+ assert(outputDf.schema.fieldNames.contains("quantiles") === false)
+
+ outputDf.select("features", "prediction")
+ .collect().foreach {
+ case Row(features: Vector, prediction: Double) =>
+ assert(prediction ~== model.predict(features) relTol 1E-5)
}
}
}