aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2015-08-13 09:17:19 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-13 09:17:19 -0700
commit2932e25da4532de9e86b01d08bce0cb680874e70 (patch)
tree07acdab578dc3c7a543babfcef1a6e3a8227807d /mllib/src/main
parent69930310115501f0de094fe6f5c6c60dade342bd (diff)
downloadspark-2932e25da4532de9e86b01d08bce0cb680874e70.tar.gz
spark-2932e25da4532de9e86b01d08bce0cb680874e70.tar.bz2
spark-2932e25da4532de9e86b01d08bce0cb680874e70.zip
[SPARK-9073] [ML] spark.ml Models copy() should call setParent when there is a parent
Copied ML models must have the same parent of original ones Author: lewuathe <lewuathe@me.com> Author: Lewuathe <lewuathe@me.com> Closes #7447 from Lewuathe/SPARK-9073.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala2
20 files changed, 22 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index aef2c019d2..a3e59401c5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -198,6 +198,6 @@ class PipelineModel private[ml] (
}
override def copy(extra: ParamMap): PipelineModel = {
- new PipelineModel(uid, stages.map(_.copy(extra)))
+ new PipelineModel(uid, stages.map(_.copy(extra))).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 29598f3f05..6f70b96b17 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -141,6 +141,7 @@ final class DecisionTreeClassificationModel private[ml] (
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
+ .setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index c3891a9599..3073a2a61c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -196,7 +196,7 @@ final class GBTClassificationModel(
}
override def copy(extra: ParamMap): GBTClassificationModel = {
- copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra)
+ copyValues(new GBTClassificationModel(uid, _trees, _treeWeights), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 5bcd7117b6..21fbe38ca8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -468,7 +468,7 @@ class LogisticRegressionModel private[ml] (
}
override def copy(extra: ParamMap): LogisticRegressionModel = {
- copyValues(new LogisticRegressionModel(uid, weights, intercept), extra)
+ copyValues(new LogisticRegressionModel(uid, weights, intercept), extra).setParent(parent)
}
override protected def raw2prediction(rawPrediction: Vector): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1741f19dc9..1132d8046d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -138,7 +138,7 @@ final class OneVsRestModel private[ml] (
override def copy(extra: ParamMap): OneVsRestModel = {
val copied = new OneVsRestModel(
uid, labelMetadata, models.map(_.copy(extra).asInstanceOf[ClassificationModel[_, _]]))
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index 156050aaf7..11a6d72468 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -189,6 +189,7 @@ final class RandomForestClassificationModel private[ml] (
override def copy(extra: ParamMap): RandomForestClassificationModel = {
copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra)
+ .setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 67e4785bc3..cfca494dcf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -90,7 +90,9 @@ final class Bucketizer(override val uid: String)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
- override def copy(extra: ParamMap): Bucketizer = defaultCopy(extra)
+ override def copy(extra: ParamMap): Bucketizer = {
+ defaultCopy[Bucketizer](extra).setParent(parent)
+ }
}
private[feature] object Bucketizer {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index ecde808105..938447447a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -114,6 +114,6 @@ class IDFModel private[ml] (
override def copy(extra: ParamMap): IDFModel = {
val copied = new IDFModel(uid, idfModel)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 9a473dd237..1b494ec8b1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -173,6 +173,6 @@ class MinMaxScalerModel private[ml] (
override def copy(extra: ParamMap): MinMaxScalerModel = {
val copied = new MinMaxScalerModel(uid, originalMin, originalMax)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 2d3bb680cf..539084704b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -125,6 +125,6 @@ class PCAModel private[ml] (
override def copy(extra: ParamMap): PCAModel = {
val copied = new PCAModel(uid, pcaModel)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 72b545e5db..f6d0b0c0e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -136,6 +136,6 @@ class StandardScalerModel private[ml] (
override def copy(extra: ParamMap): StandardScalerModel = {
val copied = new StandardScalerModel(uid, scaler)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index e4485eb038..9e4b0f0add 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -168,7 +168,7 @@ class StringIndexerModel private[ml] (
override def copy(extra: ParamMap): StringIndexerModel = {
val copied = new StringIndexerModel(uid, labels)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index c73bdccdef..6875aefe06 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -405,6 +405,6 @@ class VectorIndexerModel private[ml] (
override def copy(extra: ParamMap): VectorIndexerModel = {
val copied = new VectorIndexerModel(uid, numFeatures, categoryMaps)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 29acc3eb58..5af775a415 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -221,6 +221,6 @@ class Word2VecModel private[ml] (
override def copy(extra: ParamMap): Word2VecModel = {
val copied = new Word2VecModel(uid, wordVectors)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2e44cd4cc6..7db8ad8d27 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -219,7 +219,7 @@ class ALSModel private[ml] (
override def copy(extra: ParamMap): ALSModel = {
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index dc94a14014..a2bcd67401 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -114,7 +114,7 @@ final class DecisionTreeRegressionModel private[ml] (
}
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
- copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra)
+ copyValues(new DecisionTreeRegressionModel(uid, rootNode), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index 5633bc3202..b66e61f37d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -185,7 +185,7 @@ final class GBTRegressionModel(
}
override def copy(extra: ParamMap): GBTRegressionModel = {
- copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra)
+ copyValues(new GBTRegressionModel(uid, _trees, _treeWeights), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 92d819bad8..884003eb38 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -312,7 +312,7 @@ class LinearRegressionModel private[ml] (
override def copy(extra: ParamMap): LinearRegressionModel = {
val newModel = copyValues(new LinearRegressionModel(uid, weights, intercept))
if (trainingSummary.isDefined) newModel.setSummary(trainingSummary.get)
- newModel
+ newModel.setParent(parent)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index db75c0d263..2f36da371f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -151,7 +151,7 @@ final class RandomForestRegressionModel private[ml] (
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
- copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra)
+ copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent)
}
override def toString: String = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index f979319cc4..4792eb0f0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -160,6 +160,6 @@ class CrossValidatorModel private[ml] (
uid,
bestModel.copy(extra).asInstanceOf[Model[_]],
avgMetrics.clone())
- copyValues(copied, extra)
+ copyValues(copied, extra).setParent(parent)
}
}