aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 4
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala | 40
2 files changed, 43 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 0d3181d0ac..7a6ab618a1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -69,7 +69,8 @@ private[r] object GeneralizedLinearRegressionWrapper
link: String,
tol: Double,
maxIter: Int,
- weightCol: String): GeneralizedLinearRegressionWrapper = {
+ weightCol: String,
+ regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
val rFormulaModel = rFormula.fit(data)
@@ -86,6 +87,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setTol(tol)
.setMaxIter(maxIter)
.setWeightCol(weightCol)
+ .setRegParam(regParam)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index a4568e83fa..d8032c4e17 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -1034,6 +1034,46 @@ class GeneralizedLinearRegressionSuite
.setFamily("gaussian")
.fit(datasetGaussianIdentity.as[LabeledPoint])
}
+
+  test("generalized linear regression: regularization parameter") {
+    /*
+      Checks summary.aic against SparkR reference values for several regParam settings.
+      R code:
+      a1 <- c(0, 1, 2, 3)
+      a2 <- c(5, 2, 1, 3)
+      b <- c(1, 0, 1, 0)
+      data <- as.data.frame(cbind(a1, a2, b))
+      df <- suppressWarnings(createDataFrame(data))
+
+      for (regParam in c(0.0, 0.1, 1.0)) {
+        model <- spark.glm(df, b ~ a1 + a2, regParam = regParam)
+        print(as.vector(summary(model)$aic))
+      }
+
+      [1] 12.88188
+      [1] 12.92681
+      [1] 13.32836
+    */
+    // Each feature vector is (a2, a1) from the R code above; the label is b.
+    val dataset = spark.createDataFrame(Seq(
+      LabeledPoint(1, Vectors.dense(5, 0)),
+      LabeledPoint(0, Vectors.dense(2, 1)),
+      LabeledPoint(1, Vectors.dense(1, 2)),
+      LabeledPoint(0, Vectors.dense(3, 3))
+    ))
+    // regParam -> AIC printed by the R code above.
+    val expectedAicByRegParam = Seq(0.0 -> 12.88188, 0.1 -> 12.92681, 1.0 -> 13.32836)
+
+    for ((regParam, expectedAic) <- expectedAicByRegParam) {
+      val trainer = new GeneralizedLinearRegression()
+        .setRegParam(regParam)
+        .setLabelCol("label")
+        .setFeaturesCol("features")
+      val actual = trainer.fit(dataset).summary.aic
+      assert(actual ~= expectedAic absTol 1e-4,
+        s"Model mismatch: GLM with regParam = $regParam.")
+    }
+  }
}
object GeneralizedLinearRegressionSuite {