aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-12 11:27:16 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-12 11:27:16 -0700
commit101663f1ae222a919fc40510aa4f2bad22d1be6f (patch)
treead1288139993cb7cc07a1e366e80616f1221cf0b /mllib/src/test/scala/org/apache
parent75e05a5a964c9585dd09a2ef6178881929bab1f1 (diff)
downloadspark-101663f1ae222a919fc40510aa4f2bad22d1be6f.tar.gz
spark-101663f1ae222a919fc40510aa4f2bad22d1be6f.tar.bz2
spark-101663f1ae222a919fc40510aa4f2bad22d1be6f.zip
[SPARK-13322][ML] AFTSurvivalRegression supports feature standardization
## What changes were proposed in this pull request? AFTSurvivalRegression should support feature standardization; it will improve the convergence rate. Test the convergence rate on the [Ovarian](https://stat.ethz.ch/R-manual/R-devel/library/survival/html/ovarian.html) data, which is a standard dataset that comes with the survival library in R: * without standardization (before this PR) -> 74 iterations. * with standardization (after this PR) -> 38 iterations. After this fix, training with or without ```standardization``` will converge to the same solution; ```standardization = false``` runs the same code path as ```standardization = true```, because if the features are not standardized at all, convergence issues can result when the features have very different scales. This behavior is the same as ML [```LinearRegression``` and ```LogisticRegression```](https://issues.apache.org/jira/browse/SPARK-8522). See more discussion about this topic at #11247. cc mengxr ## How was this patch tested? unit test. Author: Yanbo Liang <ybliang8@gmail.com> Closes #11365 from yanboliang/spark-13322.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala22
1 file changed, 22 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
index f4844cc671..76891ad562 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala
@@ -33,6 +33,7 @@ class AFTSurvivalRegressionSuite
@transient var datasetUnivariate: DataFrame = _
@transient var datasetMultivariate: DataFrame = _
+ @transient var datasetUnivariateScaled: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -42,6 +43,11 @@ class AFTSurvivalRegressionSuite
datasetMultivariate = sqlContext.createDataFrame(
sc.parallelize(generateAFTInput(
2, Array(0.9, -1.3), Array(0.7, 1.2), 1000, 42, 1.5, 2.5, 2.0)))
+ datasetUnivariateScaled = sqlContext.createDataFrame(
+ sc.parallelize(generateAFTInput(
+ 1, Array(5.5), Array(0.8), 1000, 42, 1.0, 2.0, 2.0)).map { x =>
+ AFTPoint(Vectors.dense(x.features(0) * 1.0E3), x.label, x.censor)
+ })
}
/**
@@ -356,6 +362,22 @@ class AFTSurvivalRegressionSuite
}
}
+ test("numerical stability of standardization") {
+ val trainer = new AFTSurvivalRegression()
+ val model1 = trainer.fit(datasetUnivariate)
+ val model2 = trainer.fit(datasetUnivariateScaled)
+
+    /**
+     * During training we standardize the dataset first, so no matter what
+     * scaling factor we multiply into the dataset, the convergence rate should
+     * be the same, and the coefficients should equal the original coefficients
+     * multiplied by the scaling factor. Scaling has no effect on the intercept
+     * and scale.
+     */
+ assert(model1.coefficients(0) ~== model2.coefficients(0) * 1.0E3 absTol 0.01)
+ assert(model1.intercept ~== model2.intercept absTol 0.01)
+ assert(model1.scale ~== model2.scale absTol 0.01)
+ }
+
test("read/write") {
def checkModelData(
model: AFTSurvivalRegressionModel,