aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-11-05 22:38:07 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-11-05 22:38:07 -0700
commit23ce0d1e91076d90c1a87d698a94d283d08cf899 (patch)
tree1224cab695176c4a59b841d08111b099aee7e2d4 /mllib/src/main
parent15d392688456ad9f963417843c52a7b610f771d2 (diff)
downloadspark-23ce0d1e91076d90c1a87d698a94d283d08cf899.tar.gz
spark-23ce0d1e91076d90c1a87d698a94d283d08cf899.tar.bz2
spark-23ce0d1e91076d90c1a87d698a94d283d08cf899.zip
[SPARK-18276][ML] ML models should copy the training summary and set parent
## What changes were proposed in this pull request? Only some of the models which contain a training summary currently set the summaries in the copy method. Linear/Logistic regression do, GLR, GMM, KM, and BKM do not. Additionally, these copy methods did not set the parent pointer of the copied model. This patch modifies the copy methods of the four models mentioned above to copy the training summary and set the parent. ## How was this patch tested? Add unit tests in Linear/Logistic/GeneralizedLinear regression and GaussianMixture/KMeans/BisectingKMeans to check the parent pointer of the copied model and check that the copied model has a summary. Author: sethah <seth.hendrickson16@gmail.com> Closes #15773 from sethah/SPARK-18276.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala2
5 files changed, 14 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index 2718dd93dc..f8a606d60b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): BisectingKMeansModel = {
- val copied = new BisectingKMeansModel(uid, parentModel)
- copyValues(copied, extra)
+ val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(this.parent)
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 8fac63fefb..a0bd66e731 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
- val copied = new GaussianMixtureModel(uid, weights, gaussians)
- copyValues(copied, extra).setParent(this.parent)
+ val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(this.parent)
}
@Since("2.0.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 85bb8c93b3..a0d481b294 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -108,8 +108,9 @@ class KMeansModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): KMeansModel = {
- val copied = new KMeansModel(uid, parentModel)
- copyValues(copied, extra)
+ val copied = copyValues(new KMeansModel(uid, parentModel), extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(this.parent)
}
/** @group setParam */
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 8656ecf609..1938e8ecc5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] (
@Since("2.0.0")
override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = {
- copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra)
- .setParent(parent)
+ val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept),
+ extra)
+ if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get)
+ copied.setParent(parent)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 0fdba1cb88..5d1a39f7c1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] (
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
validationMetrics.clone())
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
@Since("2.0.0")