diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index dbd752d2aa..76891ad562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ + @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite datasetMultivariate = sqlContext.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + datasetUnivariateScaled = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }) } /** @@ -347,6 +353,31 @@ class AFTSurvivalRegressionSuite } } + test("should support all NumericType labels") { + val aft = new AFTSurvivalRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( + aft, isClassification = false, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + } + + test("numerical stability of standardization") { + val trainer = new AFTSurvivalRegression() + val model1 = trainer.fit(datasetUnivariate) + val model2 = trainer.fit(datasetUnivariateScaled) + + /** + * During training we standardize the dataset first, so no matter how we multiple + * a scaling factor into the dataset, the convergence rate should be the same, + * and the coefficients should equal to the original coefficients multiple by + * the scaling factor. It will have no effect on the intercept and scale. + */ + assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01) + assert(model1.intercept ~== model2.intercept absTol 0.01) + assert(model1.scale ~== model2.scale absTol 0.01) + } + test("read/write") { def checkModelData( model: AFTSurvivalRegressionModel, |