aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-13 13:20:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-13 13:20:29 -0700
commit0d17593b32c12c3e39575430aa85cf20e56fae6a (patch)
tree26017e8e8f6f26f6f2902ee985dd278270c3af82
parentb0adb9f543fbac16ea14c64eef6ba032a9919039 (diff)
downloadspark-0d17593b32c12c3e39575430aa85cf20e56fae6a.tar.gz
spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.tar.bz2
spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.zip
[SPARK-14461][ML] GLM training summaries should provide solver
## What changes were proposed in this pull request? GLM training summaries should provide solver. ## How was this patch tested? Unit tests. cc jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #12253 from yanboliang/spark-14461.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala4
2 files changed, 11 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 00cf25dc54..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
- 1)
+ 1,
+ getSolver)
return model.setSummary(trainingSummary)
}
@@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations)
+ irlsModel.numIterations,
+ getSolver)
model.setSummary(trainingSummary)
}
@@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
@@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int) extends Serializable {
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 4905f3e068..3ecc210abd 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -626,6 +626,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("glm summary: binomial family with weight") {
@@ -739,6 +740,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("glm summary: poisson family with weight") {
@@ -855,6 +857,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("glm summary: gamma family with weight") {
@@ -968,6 +971,7 @@ class GeneralizedLinearRegressionSuite
assert(summary.residualDegreeOfFreedom === residualDegreeOfFreedomR)
assert(summary.residualDegreeOfFreedomNull === residualDegreeOfFreedomNullR)
assert(summary.aic ~== aicR absTol 1E-3)
+ assert(summary.solver === "irls")
}
test("read/write") {