diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-11-05 22:38:07 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-11-05 22:38:07 -0700 |
commit | 23ce0d1e91076d90c1a87d698a94d283d08cf899 (patch) | |
tree | 1224cab695176c4a59b841d08111b099aee7e2d4 | |
parent | 15d392688456ad9f963417843c52a7b610f771d2 (diff) | |
download | spark-23ce0d1e91076d90c1a87d698a94d283d08cf899.tar.gz spark-23ce0d1e91076d90c1a87d698a94d283d08cf899.tar.bz2 spark-23ce0d1e91076d90c1a87d698a94d283d08cf899.zip |
[SPARK-18276][ML] ML models should copy the training summary and set parent
## What changes were proposed in this pull request?
Only some of the models which contain a training summary currently set the summaries in the copy method. Linear/Logistic regression do, GLR, GMM, KM, and BKM do not. Additionally, these copy methods did not set the parent pointer of the copied model. This patch modifies the copy methods of the four models mentioned above to copy the training summary and set the parent.
## How was this patch tested?
Add unit tests in Linear/Logistic/GeneralizedLinear regression and GaussianMixture/KMeans/BisectingKMeans to check the parent pointer of the copied model and check that the copied model has a summary.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #15773 from sethah/SPARK-18276.
12 files changed, 62 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala index 2718dd93dc..f8a606d60b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala @@ -94,8 +94,9 @@ class BisectingKMeansModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): BisectingKMeansModel = { - val copied = new BisectingKMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new BisectingKMeansModel(uid, parentModel), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index 8fac63fefb..a0bd66e731 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -89,8 +89,9 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GaussianMixtureModel = { - val copied = new GaussianMixtureModel(uid, weights, gaussians) - copyValues(copied, extra).setParent(this.parent) + val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } @Since("2.0.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 85bb8c93b3..a0d481b294 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -108,8 +108,9 @@ class KMeansModel private[ml] ( @Since("1.5.0") override def copy(extra: ParamMap): KMeansModel = { - val copied = new KMeansModel(uid, parentModel) - copyValues(copied, extra) + val copied = copyValues(new KMeansModel(uid, parentModel), extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(this.parent) } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 8656ecf609..1938e8ecc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -776,8 +776,10 @@ class GeneralizedLinearRegressionModel private[ml] ( @Since("2.0.0") override def copy(extra: ParamMap): GeneralizedLinearRegressionModel = { - copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), extra) - .setParent(parent) + val copied = copyValues(new GeneralizedLinearRegressionModel(uid, coefficients, intercept), + extra) + if (trainingSummary.isDefined) copied.setSummary(trainingSummary.get) + copied.setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index 0fdba1cb88..5d1a39f7c1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -221,7 +221,7 @@ class TrainValidationSplitModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], validationMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } @Since("2.0.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 8771fd2e9d..2877285eb4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, SparseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext @@ -141,6 +141,12 @@ class LogisticRegressionSuite assert(model.getProbabilityCol === "probability") assert(model.intercept !== 0.0) assert(model.hasParent) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("empty probabilityCol") { @@ -251,9 +257,6 @@ class LogisticRegressionSuite mlr.setFitIntercept(false) val mlrModel = mlr.fit(smallMultinomialDataset) assert(mlrModel.interceptVector === Vectors.sparse(3, Seq())) - - // copied model must have the same parent. - MLTestingUtils.checkCopy(model) } test("logistic regression with setters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index f2368a9f8d..49797d938d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -41,6 +42,13 @@ class BisectingKMeansSuite assert(bkm.getPredictionCol === "prediction") assert(bkm.getMaxIter === 20) assert(bkm.getMinDivisibleClusterSize === 1.0) + val model = bkm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("setter/getter") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index 003fa6abf6..7165b63ed3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Dataset @@ -43,6 +44,13 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(gm.getPredictionCol === "prediction") assert(gm.getMaxIter === 100) assert(gm.getTol === 0.01) + val model = gm.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index ca39265355..73972557d2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -19,7 +19,8 @@ package org.apache.spark.ml.clustering import org.apache.spark.SparkFunSuite import org.apache.spark.ml.linalg.{Vector, Vectors} -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.param.ParamMap +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} @@ -47,6 +48,13 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL) assert(kmeans.getInitSteps === 2) assert(kmeans.getTol === 1e-4) + val model = kmeans.setMaxIter(1).fit(dataset) + + // copied model must have the same parent + MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) } test("set parameters") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index ac1ef5feb9..111bc97464 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{BLAS, DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random._ @@ -183,6 +183,9 @@ class GeneralizedLinearRegressionSuite // copied model must have the same parent. MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) assert(model.getFeaturesCol === "features") assert(model.getPredictionCol === "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index c0e8afbf5e..df97d0b2ae 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors} -import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} @@ -143,6 +143,9 @@ class LinearRegressionSuite // copied model must have the same parent. MLTestingUtils.checkCopy(model) + assert(model.hasSummary) + val copiedModel = model.copy(ParamMap.empty) + assert(copiedModel.hasSummary) model.transform(datasetWithDenseFeature) .select("label", "prediction") diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 87100ae2e3..4463a9b6e5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -22,11 +22,11 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.linalg.{DenseMatrix, Vectors} +import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.sql.Dataset import org.apache.spark.sql.types.StructType @@ -78,6 +78,10 @@ class TrainValidationSplitSuite .setTrainRatio(0.5) .setSeed(42L) val cvModel = cv.fit(dataset) + + // copied model must have the same paren. + MLTestingUtils.checkCopy(cvModel) + val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) assert(parent.getMaxIter === 10) |