aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-04-13 13:20:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-13 13:20:29 -0700
commit0d17593b32c12c3e39575430aa85cf20e56fae6a (patch)
tree26017e8e8f6f26f6f2902ee985dd278270c3af82 /mllib/src/main/scala/org/apache
parentb0adb9f543fbac16ea14c64eef6ba032a9919039 (diff)
downloadspark-0d17593b32c12c3e39575430aa85cf20e56fae6a.tar.gz
spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.tar.bz2
spark-0d17593b32c12c3e39575430aa85cf20e56fae6a.zip
[SPARK-14461][ML] GLM training summaries should provide solver
## What changes were proposed in this pull request? GLM training summaries should provide solver. ## How was this patch tested? Unit tests. cc jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #12253 from yanboliang/spark-14461.
Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala10
1 files changed, 7 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 00cf25dc54..e92a3e7fa1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -237,7 +237,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
wlsModel.diagInvAtWA.toArray,
- 1)
+ 1,
+ getSolver)
return model.setSummary(trainingSummary)
}
@@ -257,7 +258,8 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
predictionColName,
model,
irlsModel.diagInvAtWA.toArray,
- irlsModel.numIterations)
+ irlsModel.numIterations,
+ getSolver)
model.setSummary(trainingSummary)
}
@@ -781,6 +783,7 @@ object GeneralizedLinearRegressionModel extends MLReadable[GeneralizedLinearRegr
* @param model the model that should be summarized
* @param diagInvAtWA diagonal of matrix (A^T * W * A)^-1 in the last iteration
* @param numIterations number of iterations
+ * @param solver the solver algorithm used for model training
*/
@Since("2.0.0")
@Experimental
@@ -789,7 +792,8 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0") val predictionCol: String,
@Since("2.0.0") val model: GeneralizedLinearRegressionModel,
private val diagInvAtWA: Array[Double],
- @Since("2.0.0") val numIterations: Int) extends Serializable {
+ @Since("2.0.0") val numIterations: Int,
+ @Since("2.0.0") val solver: String) extends Serializable {
import GeneralizedLinearRegression._