aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala2
2 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 21fbe38ca8..a460262b87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
- copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent)
+ val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
+ if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
+ newModel.setParent(parent)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 884003eb38..e4602d36cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -310,7 +310,7 @@ class LinearRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LinearRegressionModel = {
- val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept))
+ val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
}