aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala31
1 files changed, 31 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index dbd752d2aa..76891ad562 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite
@transient var datasetUnivariate: DataFrame = _
@transient var datasetMultivariate: DataFrame = _
+ @transient var datasetUnivariateScaled: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite
datasetMultivariate = sqlContext.createDataFrame(
sc.parallelize(generateAFTInput(
2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+ datasetUnivariateScaled = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
+ AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
+ })
}
/**
@@ -347,6 +353,31 @@ class AFTSurvivalRegressionSuite
}
}
+ test("should support all NumericType labels") {
+ val aft = new AFTSurvivalRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression](
+ aft, isClassification = false, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients === actual.coefficients)
+ }
+ }
+
+ test("numerical stability of standardization") {
+ val trainer = new AFTSurvivalRegression()
+ val model1 = trainer.fit(datasetUnivariate)
+ val model2 = trainer.fit(datasetUnivariateScaled)
+
+ /**
+ * During training we standardize the dataset first, so no matter how we multiple
+ * a scaling factor into the dataset, the convergence rate should be the same,
+ * and the coefficients should equal to the original coefficients multiple by
+ * the scaling factor. It will have no effect on the intercept and scale.
+ */
+ assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01)
+ assert(model1.intercept ~== model2.intercept absTol 0.01)
+ assert(model1.scale ~== model2.scale absTol 0.01)
+ }
+
test("read/write") {
def checkModelData(
model: AFTSurvivalRegressionModel,