aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-08 11:11:35 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-08 11:11:35 -0700
commit5b2192e846b843d8a0cb9427d19bb677431194a0 (patch)
treec7a88191f6f2d040d5527e5387ce558624776106 /mllib
parent990c9f79c28db501018a0a3af446ff879962475d (diff)
downloadspark-5b2192e846b843d8a0cb9427d19bb677431194a0.tar.gz
spark-5b2192e846b843d8a0cb9427d19bb677431194a0.tar.bz2
spark-5b2192e846b843d8a0cb9427d19bb677431194a0.zip
[SPARK-10480] [ML] Fix ML.LinearRegressionModel.copy()
This PR fix two model ```copy()``` related issues: [SPARK-10480](https://issues.apache.org/jira/browse/SPARK-10480) ```ML.LinearRegressionModel.copy()``` ignored argument ```extra```, it will not take effect when users setting this parameter. [SPARK-10479](https://issues.apache.org/jira/browse/SPARK-10479) ```ML.LogisticRegressionModel.copy()``` should copy model summary if available. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8641 from yanboliang/linear-regression-copy.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala2
2 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 21fbe38ca8..a460262b87 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -468,7 +468,9 @@ class LogisticRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
- copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent)
+ val newModel = copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
+ if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
+ newModel.setParent(parent)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 884003eb38..e4602d36cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -310,7 +310,7 @@ class LinearRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LinearRegressionModel = {
- val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept))
+ val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept), extra)
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
newModel.setParent(parent)
}