diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-04-13 13:20:29 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-13 13:20:29 -0700 |
commit | 0d17593b32c12c3e39575430aa85cf20e56fae6a (patch) | |
tree | 26017e8e8f6f26f6f2902ee985dd278270c3af82 /mllib/src/main/scala/org/apache | |
parent | b0adb9f543fbac16ea14c64eef6ba032a9919039 (diff) | |
download | spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.tar.gz spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.tar.bz2 spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.zip |
[SPARK-14461][ML] GLM training summaries should provide solver
## What changes were proposed in this pull request?
GLM training summaries should provide solver.
## How was this patch tested?
Unit tests.
cc jkbradley
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #12253 from yanboliang/spark-14461.
Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 00cf25dc54..e92a3e7fa1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, wlsModel.diagInvAtWA.toArray, - 1) + 1, + getSolver) return model.setSummary(trainingSummary) } @@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val predictionColName, model, irlsModel.diagInvAtWA.toArray, - irlsModel.numIterations) + irlsModel.numIterations, + getSolver) model.setSummary(trainingSummary) } @@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr * @param model the model that should be summarized * @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration * @param numIterations number of iterations + * @param solver the solver algorithm used for model training */ @Since("2.0.0") @Experimental @@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] ( @Since("2.0.0") val predictionCol: String, @Since("2.0.0") val model: GeneralizedLinearRegressionModel, private val diagInvAtWA: Array[Double], - @Since("2.0.0") val numIterations: Int) extends Serializable { + @Since("2.0.0") val numIterations: Int, + @Since("2.0.0") val solver: String) extends Serializable { import GeneralizedLinearRegression._ |