diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 74 |
1 files changed, 49 insertions, 25 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index ca7140a45e..359f310271 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -22,8 +22,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.MLTestingUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vectors} -import org.apache.spark.mllib.linalg.BLAS +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -59,16 +58,20 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex assert(aftr.getFitIntercept) assert(aftr.getMaxIter === 100) assert(aftr.getTol === 1E-6) - val model = aftr.fit(datasetUnivariate) + val model = aftr.setQuantileProbabilities(Array(0.1, 0.8)) + .setQuantilesCol("quantiles") + .fit(datasetUnivariate) // copied model must have the same parent. MLTestingUtils.checkCopy(model) model.transform(datasetUnivariate) - .select("label", "prediction") + .select("label", "prediction", "quantiles") .collect() assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") + assert(model.getQuantileProbabilities === Array(0.1, 0.8)) + assert(model.getQuantilesCol === "quantiles") assert(model.intercept !== 0.0) assert(model.hasParent) } @@ -108,7 +111,10 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex } test("aft survival regression with univariate") { - val trainer = new AFTSurvivalRegression + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") val model = trainer.fit(datasetUnivariate) /* @@ -159,23 +165,25 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 0.1879174 2.6801195 14.5779394 */ val features = Vectors.dense(6.559282795753792) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 4.494763 val quantilePredictR = Vectors.dense(0.1879174, 2.6801195, 14.5779394) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetUnivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetUnivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) } } test("aft survival regression with multivariate") { - val trainer = new AFTSurvivalRegression + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") val model = trainer.fit(datasetMultivariate) /* @@ -227,23 +235,26 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 0.5287044 3.3285858 10.7517072 */ val features = Vectors.dense(2.233396950271428, -2.5321374085997683) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 4.761219 val quantilePredictR = Vectors.dense(0.5287044, 3.3285858, 10.7517072) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) } } test("aft survival regression w/o intercept") { - val trainer = new AFTSurvivalRegression().setFitIntercept(false) + val quantileProbabilities = Array(0.1, 0.5, 0.9) + val trainer = new AFTSurvivalRegression() + .setQuantileProbabilities(quantileProbabilities) + .setQuantilesCol("quantiles") + .setFitIntercept(false) val model = trainer.fit(datasetMultivariate) /* @@ -294,18 +305,31 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContex [1] 1.452103 25.506077 158.428600 */ val features = Vectors.dense(2.233396950271428, -2.5321374085997683) - val quantileProbabilities = Array(0.1, 0.5, 0.9) val responsePredictR = 44.54465 val quantilePredictR = Vectors.dense(1.452103, 25.506077, 158.428600) assert(model.predict(features) ~== responsePredictR relTol 1E-3) - model.setQuantileProbabilities(quantileProbabilities) assert(model.predictQuantiles(features) ~== quantilePredictR relTol 1E-3) - model.transform(datasetMultivariate).select("features", "prediction").collect().foreach { - case Row(features: DenseVector, prediction1: Double) => - val prediction2 = math.exp(BLAS.dot(model.coefficients, features) + model.intercept) - assert(prediction1 ~== prediction2 relTol 1E-5) + model.transform(datasetMultivariate).select("features", "prediction", "quantiles") + .collect().foreach { + case Row(features: Vector, prediction: Double, quantiles: Vector) => + assert(prediction ~== model.predict(features) relTol 1E-5) + assert(quantiles ~== model.predictQuantiles(features) relTol 1E-5) + } + } + + test("aft survival regression w/o quantiles column") { + val trainer = new AFTSurvivalRegression + val model = trainer.fit(datasetUnivariate) + val outputDf = model.transform(datasetUnivariate) + + assert(outputDf.schema.fieldNames.contains("quantiles") === false) + + outputDf.select("features", "prediction") + .collect().foreach { + case Row(features: Vector, prediction: Double) => + assert(prediction ~== model.predict(features) relTol 1E-5) } } } |