diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-09-08 11:11:35 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-09-08 11:11:35 -0700 |
commit | 5b2192e846b843d8a0cb9427d19bb677431194a0 (patch) | |
tree | c7a88191f6f2d040d5527e5387ce558624776106 /mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | |
parent | 990c9f79c28db501018a0a3af446ff879962475d (diff) | |
download | spark-5b2192e846b843d8a0cb9427d19bb677431194a0.tar.gz spark-5b2192e846b843d8a0cb9427d19bb677431194a0.tar.bz2 spark-5b2192e846b843d8a0cb9427d19bb677431194a0.zip |
[SPARK-10480] [ML] Fix ML.LinearRegressionModel.copy()
This PR fix two model ```copy()``` related issues:
[SPARK-10480](https://issues.apache.org/jira/browse/SPARK-10480)
```ML.LinearRegressionModel.copy()``` ignored argument ```extra```, it will not take effect when users setting this parameter.
[SPARK-10479](https://issues.apache.org/jira/browse/SPARK-10479)
```ML.LogisticRegressionModel.copy()``` should copy model summary if available.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #8641 from yanboliang/linear-regression-copy.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21fbe38ca8..a460262b87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) + val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel.setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { |