aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXin Ren <iamshrek@126.com>2016-08-31 21:39:31 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2016-08-31 21:39:31 -0700
commit7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2 (patch)
tree402f7fc410378418369b4abb4e1b3d92a1358a5c /mllib
parentd008638fbedc857c1adc1dff399d427b8bae848e (diff)
downloadspark-7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2.tar.gz
spark-7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2.tar.bz2
spark-7a5000f39ef4f195696836f8a4e8ab4ff5c14dd2.zip
[SPARK-17241][SPARKR][MLLIB] SparkR spark.glm should have configurable regularization parameter
https://issues.apache.org/jira/browse/SPARK-17241 ## What changes were proposed in this pull request? Spark has configurable L2 regularization parameter for generalized linear regression. It is very important to have them in SparkR so that users can run ridge regression. ## How was this patch tested? Test manually on local laptop. Author: Xin Ren <iamshrek@126.com> Closes #14856 from keypointt/SPARK-17241.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala40
2 files changed, 43 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 0d3181d0ac..7a6ab618a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
link: String,
tol: Double,
maxIter: Int,
- weightCol: String): GeneralizedLinearRegressionWrapper = {
+ weightCol: String,
+ regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)
@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
+ .setRegParam(regParam)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index a4568e83fa..d8032c4e17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
.setFamily("gaussian")
.fit(datasetGaussianIdentity.as[LabeledPoint])
}
+
+ test("generalized linear regression: regularization parameter") {
+ /*
+ R code:
+
+ a1 <- c(0, 1, 2, 3)
+ a2 <- c(5, 2, 1, 3)
+ b <- c(1, 0, 1, 0)
+ data <- as.data.frame(cbind(a1, a2, b))
+ df <- suppressWarnings(createDataFrame(data))
+
+ for (regParam in c(0.0, 0.1, 1.0)) {
+ model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
+ print(as.vector(summary(model)$aic))
+ }
+
+ [1] 12.88188
+ [1] 12.92681
+ [1] 13.32836
+ */
+ val dataset = spark.createDataFrame(Seq(
+ LabeledPoint(1, Vectors.dense(5, 0)),
+ LabeledPoint(0, Vectors.dense(2, 1)),
+ LabeledPoint(1, Vectors.dense(1, 2)),
+ LabeledPoint(0, Vectors.dense(3, 3))
+ ))
+ val expected = Seq(12.88188, 12.92681, 13.32836)
+
+ var idx = 0
+ for (regParam <- Seq(0.0, 0.1, 1.0)) {
+ val trainer = new GeneralizedLinearRegression()
+ .setRegParam(regParam)
+ .setLabelCol("label")
+ .setFeaturesCol("features")
+ val model = trainer.fit(dataset)
+ val actual = model.summary.aic
+ assert(actual ~= expected(idx) absTol 1e-4, "Model mismatch: GLM with regParam = $regParam.")
+ idx += 1
+ }
+ }
}
object GeneralizedLinearRegressionSuite {