diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-12 11:27:16 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-12 11:27:16 -0700 |
commit | 101663f1ae222a919fc40510aa4f2bad22d1be6f (patch) | |
tree | ad1288139993cb7cc07a1e366e80616f1221cf0b /mllib/src/test/scala/org/apache | |
parent | 75e05a5a964c9585dd09a2ef6178881929bab1f1 (diff) | |
download | spark-101663f1ae222a919fc40510aa4f2bad22d1be6f.tar.gz spark-101663f1ae222a919fc40510aa4f2bad22d1be6f.tar.bz2 spark-101663f1ae222a919fc40510aa4f2bad22d1be6f.zip |
[SPARK-13322][ML] AFTSurvivalRegression supports feature standardization
## What changes were proposed in this pull request?
AFTSurvivalRegression should support feature standardization, it will improve the convergence rate.
Test the convergence rate on the [Ovarian](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/ovarian.html) data which is standard data comes with Survival library in R,
* without standardization(before this PR) -> 74 iterations.
* with standardization(after this PR) -> 38 iterations.
But after this fix, with or without ```standardization``` will converge to the same solution. It means that ```standardization = false``` will run the same code route as ```standardization = true```. Because if the features are not standardized at all, it will result convergency issue when the features have very different scales. This behavior is the same as ML [```LinearRegression``` and ```LogisticRegression```](https://issues.apache.org/jira/browse/SPARK-8522). See more discussion about this topic at #11247.
cc mengxr
## How was this patch tested?
unit test.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #11365 from yanboliang/spark-13322.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index f4844cc671..76891ad562 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ + @transient var datasetUnivariateScaled: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() @@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite datasetMultivariate = sqlContext.createDataFrame( sc.parallelize(generateAFTInput( 2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0))) + datasetUnivariateScaled = sqlContext.createDataFrame( + sc.parallelize(generateAFTInput( + 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x => + AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor) + }) } /** @@ -356,6 +362,22 @@ class AFTSurvivalRegressionSuite } } + test("numerical stability of standardization") { + val trainer = new AFTSurvivalRegression() + val model1 = trainer.fit(datasetUnivariate) + val model2 = trainer.fit(datasetUnivariateScaled) + + /** + * During training we standardize the dataset first, so no matter how we multiple + * a scaling factor into the dataset, the convergence rate should be the same, + * and the coefficients should equal to the original coefficients multiple by + * the scaling factor. It will have no effect on the intercept and scale. + */ + assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01) + assert(model1.intercept ~== model2.intercept absTol 0.01) + assert(model1.scale ~== model2.scale absTol 0.01) + } + test("read/write") { def checkModelData( model: AFTSurvivalRegressionModel, |