diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 4 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 |
2 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 21fbe38ca8..a460262b87 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) + val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) + newModel.setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 884003eb38..e4602d36cc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -310,7 +310,7 @@ class LinearRegressionModel private[ml] ( } override def copy(extra: ParamMap): LinearRegressionModel = { - val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) + val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) newModel.setParent(parent) } |