aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-08 11:11:35 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-08 11:11:35 -0700
commit5b2192e846b843d8a0cb9427d19bb677431194a0 (patch)
treec7a88191f6f2d040d5527e5387ce558624776106 /mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
parent990c9f79c28db501018a0a3af446ff879962475d (diff)
downloadspark-5b2192e846b843d8a0cb9427d19bb677431194a0.tar.gz
spark-5b2192e846b843d8a0cb9427d19bb677431194a0.tar.bz2
spark-5b2192e846b843d8a0cb9427d19bb677431194a0.zip
[SPARK-10480] [ML] Fix ML.LinearRegressionModel.copy()
This PR fix two model ```copy()``` related issues: [SPARK-10480](https://issues.apache.org/jira/browse/SPARK-10480) ```ML.LinearRegressionModel.copy()``` ignored argument ```extra```, it will not take effect when users setting this parameter. [SPARK-10479](https://issues.apache.org/jira/browse/SPARK-10479) ```ML.LogisticRegressionModel.copy()``` should copy model summary if available. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8641 from yanboliang/linear-regression-copy.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala4
1 files changed, 3 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 21fbe38ca8..a460262b87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
- copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent)
+ val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
+ if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
+ newModel.setParent(parent)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {