From 2932e25da4532de9e86b01d08bce0cb680874e70 Mon Sep 17 00:00:00 2001 From: lewuathe Date: Thu, 13 Aug 2015 09:17:19 -0700 Subject: [SPARK-9073] [ML] spark.ml Models copy() should call setParent when there is a parent Copied ML models must have the same parent of original ones Author: lewuathe Author: Lewuathe Closes #7447 from Lewuathe/SPARK-9073. --- mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 1 + .../main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 2 +- .../scala/org/apache/spark/ml/classification/LogisticRegression.scala | 2 +- .../src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala | 2 +- .../org/apache/spark/ml/classification/RandomForestClassifier.scala | 1 + mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 4 +++- mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2 +- .../scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../main/scala/org/apache/spark/ml/regression/LinearRegression.scala | 2 +- .../scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 +- 20 files changed, 22 insertions(+), 18 deletions(-) (limited to 'mllib/src/main/scala/org') diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index aef2c019d2..a3e59401c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -198,6 +198,6 @@ class PipelineModel private[ml] ( } override def copy(extra: ParamMap): PipelineModel = { - new PipelineModel(uid, stages.map(_.copy(extra))) + new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 29598f3f05..6f70b96b17 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -141,6 +141,7 @@ final class DecisionTreeClassificationModel private[ml] ( override def copy(extra: ParamMap): DecisionTreeClassificationModel = { copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index c3891a9599..3073a2a61c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -196,7 +196,7 @@ final class GBTClassificationModel( } override def copy(extra: ParamMap): GBTClassificationModel = { - copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 5bcd7117b6..21fbe38ca8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -468,7 +468,7 @@ class LogisticRegressionModel private[ml] ( } override def copy(extra: ParamMap): LogisticRegressionModel = { - copyValues(new LogisticRegressionModel(uid, weights, intercept), extra) + copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent) } override protected def raw2prediction(rawPrediction: Vector): Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1741f19dc9..1132d8046d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -138,7 +138,7 @@ final class OneVsRestModel private[ml] ( override def copy(extra: ParamMap): OneVsRestModel = { val copied = new OneVsRestModel( uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]])) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 156050aaf7..11a6d72468 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -189,6 +189,7 @@ final class RandomForestClassificationModel private[ml] ( override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) + .setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 67e4785bc3..cfca494dcf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } - override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra) + override def copy(extra: ParamMap): Bucketizer = { + defaultCopy[Bucketizer](extra).setParent(parent) + } } private[feature] object Bucketizer { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index ecde808105..938447447a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -114,6 +114,6 @@ class IDFModel private[ml] ( override def copy(extra: ParamMap): IDFModel = { val copied = new IDFModel(uid, idfModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index 9a473dd237..1b494ec8b1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -173,6 +173,6 @@ class MinMaxScalerModel private[ml] ( override def copy(extra: ParamMap): MinMaxScalerModel = { val copied = new MinMaxScalerModel(uid, originalMin, originalMax) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 2d3bb680cf..539084704b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -125,6 +125,6 @@ class PCAModel private[ml] ( override def copy(extra: ParamMap): PCAModel = { val copied = new PCAModel(uid, pcaModel) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72b545e5db..f6d0b0c0e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -136,6 +136,6 @@ class StandardScalerModel private[ml] ( override def copy(extra: ParamMap): StandardScalerModel = { val copied = new StandardScalerModel(uid, scaler) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index e4485eb038..9e4b0f0add 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -168,7 +168,7 @@ class StringIndexerModel private[ml] ( override def copy(extra: ParamMap): StringIndexerModel = { val copied = new StringIndexerModel(uid, labels) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index c73bdccdef..6875aefe06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] ( override def copy(extra: ParamMap): VectorIndexerModel = { val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 29acc3eb58..5af775a415 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -221,6 +221,6 @@ class Word2VecModel private[ml] ( override def copy(extra: ParamMap): Word2VecModel = { val copied = new Word2VecModel(uid, wordVectors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2e44cd4cc6..7db8ad8d27 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -219,7 +219,7 @@ class ALSModel private[ml] ( override def copy(extra: ParamMap): ALSModel = { val copied = new ALSModel(uid, rank, userFactors, itemFactors) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index dc94a14014..a2bcd67401 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -114,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] ( } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { - copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra) + copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index 5633bc3202..b66e61f37d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -185,7 +185,7 @@ final class GBTRegressionModel( } override def copy(extra: ParamMap): GBTRegressionModel = { - copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra) + copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 92d819bad8..884003eb38 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -312,7 +312,7 @@ class LinearRegressionModel private[ml] ( override def copy(extra: ParamMap): LinearRegressionModel = { val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept)) if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get) - newModel + newModel.setParent(parent) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index db75c0d263..2f36da371f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -151,7 +151,7 @@ final class RandomForestRegressionModel private[ml] ( } override def copy(extra: ParamMap): RandomForestRegressionModel = { - copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra) + copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) } override def toString: String = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index f979319cc4..4792eb0f0a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,6 +160,6 @@ class CrossValidatorModel private[ml] ( uid, bestModel.copy(extra).asInstanceOf[Model[_]], avgMetrics.clone()) - copyValues(copied, extra) + copyValues(copied, extra).setParent(parent) } } -- cgit v1.2.3